Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
deepspeed
Commits
67ea635f
Commit
67ea635f
authored
Mar 30, 2023
by
aiss
Browse files
push dsv0.8.2 version
parent
1b2721ad
Pipeline
#201
failed with stages
in 0 seconds
Changes
339
Pipelines
2
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
0 additions
and
8313 deletions
+0
-8313
deepspeed/ops/csrc/includes/strided_batch_gemm.h
deepspeed/ops/csrc/includes/strided_batch_gemm.h
+0
-195
deepspeed/ops/csrc/includes/strided_batch_gemm_hip.h
deepspeed/ops/csrc/includes/strided_batch_gemm_hip.h
+0
-196
deepspeed/ops/csrc/includes/type_shim.h
deepspeed/ops/csrc/includes/type_shim.h
+0
-119
deepspeed/ops/csrc/includes/type_shim_hip.h
deepspeed/ops/csrc/includes/type_shim_hip.h
+0
-121
deepspeed/ops/csrc/lamb/fused_lamb_cuda.cpp
deepspeed/ops/csrc/lamb/fused_lamb_cuda.cpp
+0
-109
deepspeed/ops/csrc/lamb/fused_lamb_cuda_kernel.cu
deepspeed/ops/csrc/lamb/fused_lamb_cuda_kernel.cu
+0
-474
deepspeed/ops/csrc/lamb/fused_lamb_hip_kernel.hip
deepspeed/ops/csrc/lamb/fused_lamb_hip_kernel.hip
+0
-475
deepspeed/ops/csrc/quantization/pt_binding.cpp
deepspeed/ops/csrc/quantization/pt_binding.cpp
+0
-77
deepspeed/ops/csrc/quantization/pt_binding_hip.cpp
deepspeed/ops/csrc/quantization/pt_binding_hip.cpp
+0
-78
deepspeed/ops/csrc/quantization/quantizer.hip
deepspeed/ops/csrc/quantization/quantizer.hip
+0
-1039
deepspeed/ops/csrc/sparse_attention/utils.cpp
deepspeed/ops/csrc/sparse_attention/utils.cpp
+0
-120
deepspeed/ops/csrc/transformer/cublas_wrappers.cu
deepspeed/ops/csrc/transformer/cublas_wrappers.cu
+0
-403
deepspeed/ops/csrc/transformer/cublas_wrappers.hip
deepspeed/ops/csrc/transformer/cublas_wrappers.hip
+0
-404
deepspeed/ops/csrc/transformer/dropout_kernels.cu
deepspeed/ops/csrc/transformer/dropout_kernels.cu
+0
-868
deepspeed/ops/csrc/transformer/dropout_kernels.hip
deepspeed/ops/csrc/transformer/dropout_kernels.hip
+0
-870
deepspeed/ops/csrc/transformer/ds_transformer_cuda.cpp
deepspeed/ops/csrc/transformer/ds_transformer_cuda.cpp
+0
-1051
deepspeed/ops/csrc/transformer/ds_transformer_hip.cpp
deepspeed/ops/csrc/transformer/ds_transformer_hip.cpp
+0
-1052
deepspeed/ops/csrc/transformer/gelu_kernels.cu
deepspeed/ops/csrc/transformer/gelu_kernels.cu
+0
-330
deepspeed/ops/csrc/transformer/gelu_kernels.hip
deepspeed/ops/csrc/transformer/gelu_kernels.hip
+0
-332
No files found.
Too many changes to show.
To preserve performance only
339 of 339+
files are displayed.
Plain diff
Email patch
deepspeed/ops/csrc/includes/strided_batch_gemm.h
deleted
100644 → 0
View file @
1b2721ad
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include "context.h"
/// Wrapper around `cublas_strided_batched_gemm` that keeps the GEMM
/// configuration (sizes, scaling factors, transpose flags and one algorithm id
/// each for the forward, d_A and d_B products) in one place.
template <typename T>
class StridedBatchGemm {
public:
    /// Problem description shared by Forward / ForwardPlusSave / Backward.
    struct Config {
        int batch_size;
        int m;
        int n;
        int k;
        float alpha;
        float beta;
        cublasOperation_t op_A;
        cublasOperation_t op_B;
        // [0] is used by the forward GEMM, [1] by the d_A GEMM, [2] by the d_B GEMM.
        std::array<int, 3> gemm_algos;

        Config(int batch,
               int mm,
               int nn,
               int kk,
               float param_alpha,
               float param_beta,
               cublasOperation_t opA,
               cublasOperation_t opB,
               const std::array<int, 3>& algos)
            : batch_size(batch),
              m(mm),
              n(nn),
              k(kk),
              alpha(param_alpha),
              beta(param_beta),
              op_A(opA),
              op_B(opB),
              gemm_algos(algos)
        {
        }

        /// Overwrites only the three matrix dimensions.
        void SetConfig(int mm, int nn, int kk)
        {
            m = mm;
            n = nn;
            k = kk;
        }
    };

    StridedBatchGemm(const Config& config) : _config(config) {}

    virtual ~StridedBatchGemm() {}

    /// Runs the configured batched GEMM over `bsz` batches, writing to `output`.
    void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[0]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif
    }

    /// Same GEMM as Forward, but uses `_config.batch_size` as the batch count
    /// and remembers the two input pointers so they can be read back through
    /// GetBufferA() / GetBufferB().
    void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    _config.batch_size,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[0]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif

        // The pointers are only borrowed; the caller keeps ownership of the buffers.
        k_buf = _buffer_a;
        q_buf = _buffer_b;
    }

    /// Issues two batched GEMMs that write into `inpGradA` and `inpGradB`.
    /// NOTE(review): both gradient pointers default to nullptr but are passed
    /// to the GEMM unconditionally -- confirm that callers always supply them.
    void Backward(int bsz,
                  const T* d_output,
                  const T* _buffer_a,
                  const T* _buffer_b,
                  cublasHandle_t handle,
                  T* inpGradA = nullptr,
                  T* inpGradB = nullptr)
    {
        // When A was transposed in the forward pass the roles of m and k swap.
        int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
        int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);

        int stride_a = mb * _config.n;
        int stride_b = _config.n * kb;
        int stride_c = _config.m * _config.k;

        // B needs to be transposed relative to the forward pass.
        cublasOperation_t op_b = (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

        // Calculate d_A.
        cublas_strided_batched_gemm(handle,
                                    mb,
                                    kb,
                                    _config.n,
                                    &_config.alpha,
                                    &_config.beta,
                                    (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
                                    (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b),
                                    inpGradA,
                                    CUBLAS_OP_N,
                                    op_b,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[1]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[1]));
#endif

        // A needs to be transposed relative to the forward pass.
        cublasOperation_t op_a = (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);

        stride_a = _config.m * _config.k;
        stride_b = _config.m * _config.n;
        stride_c = _config.n * _config.k;

        // Calculate d_B.
        cublas_strided_batched_gemm(handle,
                                    _config.k,
                                    _config.n,
                                    _config.m,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    d_output,
                                    inpGradB,
                                    op_a,
                                    CUBLAS_OP_N,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[2]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[2]));
#endif
    }

    // NOTE(review): despite its name this accessor returns `k`, not `n`.
    inline int GetN() const { return _config.k; }

    /// Pointer saved by the last ForwardPlusSave call; nullptr before the first call.
    inline const T* GetBufferA() const { return k_buf; }

    /// Pointer saved by the last ForwardPlusSave call; nullptr before the first call.
    inline const T* GetBufferB() const { return q_buf; }

    inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }

private:
    Config _config;
    // Initialized so the getters never expose an indeterminate pointer.
    const T* q_buf = nullptr;
    const T* k_buf = nullptr;
};
deepspeed/ops/csrc/includes/strided_batch_gemm_hip.h
deleted
100644 → 0
View file @
1b2721ad
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "context_hip.h"
/// Wrapper around `cublas_strided_batched_gemm` (hipified build) that keeps
/// the GEMM configuration (sizes, scaling factors, transpose flags and one
/// algorithm id each for the forward, d_A and d_B products) in one place.
template <typename T>
class StridedBatchGemm {
public:
    /// Problem description shared by Forward / ForwardPlusSave / Backward.
    struct Config {
        int batch_size;
        int m;
        int n;
        int k;
        float alpha;
        float beta;
        rocblas_operation op_A;
        rocblas_operation op_B;
        // [0] is used by the forward GEMM, [1] by the d_A GEMM, [2] by the d_B GEMM.
        std::array<int, 3> gemm_algos;

        Config(int batch,
               int mm,
               int nn,
               int kk,
               float param_alpha,
               float param_beta,
               rocblas_operation opA,
               rocblas_operation opB,
               const std::array<int, 3>& algos)
            : batch_size(batch),
              m(mm),
              n(nn),
              k(kk),
              alpha(param_alpha),
              beta(param_beta),
              op_A(opA),
              op_B(opB),
              gemm_algos(algos)
        {
        }

        /// Overwrites only the three matrix dimensions.
        void SetConfig(int mm, int nn, int kk)
        {
            m = mm;
            n = nn;
            k = kk;
        }
    };

    StridedBatchGemm(const Config& config) : _config(config) {}

    virtual ~StridedBatchGemm() {}

    /// Runs the configured batched GEMM over `bsz` batches, writing to `output`.
    void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[0]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif
    }

    /// Same GEMM as Forward, but uses `_config.batch_size` as the batch count
    /// and remembers the two input pointers so they can be read back through
    /// GetBufferA() / GetBufferB().
    void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
    {
        int stride_a = _config.m * _config.k;
        int stride_b = _config.n * _config.k;
        int stride_c = _config.m * _config.n;

        cublas_strided_batched_gemm(handle,
                                    _config.m,
                                    _config.n,
                                    _config.k,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    _buffer_b,
                                    output,
                                    _config.op_A,
                                    _config.op_B,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    _config.batch_size,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[0]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif

        // The pointers are only borrowed; the caller keeps ownership of the buffers.
        k_buf = _buffer_a;
        q_buf = _buffer_b;
    }

    /// Issues two batched GEMMs that write into `inpGradA` and `inpGradB`.
    /// NOTE(review): both gradient pointers default to nullptr but are passed
    /// to the GEMM unconditionally -- confirm that callers always supply them.
    void Backward(int bsz,
                  const T* d_output,
                  const T* _buffer_a,
                  const T* _buffer_b,
                  rocblas_handle handle,
                  T* inpGradA = nullptr,
                  T* inpGradB = nullptr)
    {
        // When A was transposed in the forward pass the roles of m and k swap.
        int mb = (_config.op_A == rocblas_operation_transpose ? _config.k : _config.m);
        int kb = (_config.op_A == rocblas_operation_transpose ? _config.m : _config.k);

        int stride_a = mb * _config.n;
        int stride_b = _config.n * kb;
        int stride_c = _config.m * _config.k;

        // B needs to be transposed relative to the forward pass.
        rocblas_operation op_b = (_config.op_B == rocblas_operation_transpose
                                      ? rocblas_operation_none
                                      : rocblas_operation_transpose);

        // Calculate d_A.
        cublas_strided_batched_gemm(
            handle,
            mb,
            kb,
            _config.n,
            &_config.alpha,
            &_config.beta,
            (_config.op_A == rocblas_operation_transpose ? _buffer_b : d_output),
            (_config.op_A == rocblas_operation_transpose ? d_output : _buffer_b),
            inpGradA,
            rocblas_operation_none,
            op_b,
            stride_a,
            stride_b,
            stride_c,
            bsz,
#ifdef __HIP_PLATFORM_HCC__
            rocblas_gemm_algo(_config.gemm_algos[1]));
#else
            cublasGemmAlgo_t(_config.gemm_algos[1]));
#endif

        // A needs to be transposed relative to the forward pass.
        rocblas_operation op_a = (_config.op_A == rocblas_operation_transpose
                                      ? rocblas_operation_none
                                      : rocblas_operation_transpose);

        stride_a = _config.m * _config.k;
        stride_b = _config.m * _config.n;
        stride_c = _config.n * _config.k;

        // Calculate d_B.
        cublas_strided_batched_gemm(handle,
                                    _config.k,
                                    _config.n,
                                    _config.m,
                                    &_config.alpha,
                                    &_config.beta,
                                    _buffer_a,
                                    d_output,
                                    inpGradB,
                                    op_a,
                                    rocblas_operation_none,
                                    stride_a,
                                    stride_b,
                                    stride_c,
                                    bsz,
#ifdef __HIP_PLATFORM_HCC__
                                    rocblas_gemm_algo(_config.gemm_algos[2]));
#else
                                    cublasGemmAlgo_t(_config.gemm_algos[2]));
#endif
    }

    // NOTE(review): despite its name this accessor returns `k`, not `n`.
    inline int GetN() const { return _config.k; }

    /// Pointer saved by the last ForwardPlusSave call; nullptr before the first call.
    inline const T* GetBufferA() const { return k_buf; }

    /// Pointer saved by the last ForwardPlusSave call; nullptr before the first call.
    inline const T* GetBufferB() const { return q_buf; }

    inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }

private:
    Config _config;
    // Initialized so the getters never expose an indeterminate pointer.
    const T* q_buf = nullptr;
    const T* k_buf = nullptr;
};
deepspeed/ops/csrc/includes/type_shim.h
deleted
100644 → 0
View file @
1b2721ad
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
// Switches on TYPE (an at::ScalarType) and, for Float, Half and BFloat16,
// defines the alias `scalar_t_<LEVEL>` for the matching C++ type and then
// executes the statements passed as the variadic arguments. Any other type
// raises AT_ERROR with NAME in the message.
// Note: despite the macro name, BFloat16 is handled as well.
// The error path calls `toString` unqualified, so it must be resolvable at the
// expansion site.
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
// Same dispatch scheme as DISPATCH_FLOAT_AND_HALF with an additional Double
// case: for Double, Float, Half and BFloat16 the alias `scalar_t_<LEVEL>` is
// defined and the variadic statements are executed; any other type raises
// AT_ERROR with NAME in the message.
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
// Dispatch restricted to Double and Float: defines `scalar_t_<LEVEL>` as
// double or float and executes the variadic statements; any other type raises
// AT_ERROR with NAME in the message.
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
/// Sums `val` across the threads of a block, using `x` as scratch space.
///
/// @param x            Scratch array indexed by the flattened thread id; it is
///                     written whenever the block has at least 64 threads.
/// @param val          This thread's contribution.
/// @param lanes        The shuffle steps run for offsets 16, 8, ... down to
///                     `lanes`. Intended to be <= 32 and must be >= 1: with
///                     `lanes <= 0` the shuffle loop condition never turns false.
/// @param share_result When true, threads with `tid < lanes` store their
///                     result into `x[tid]` and the whole block synchronizes.
/// @return The reduced value for threads with `tid < 32`; every other thread
///         gets its own `val` back (previously an uninitialized value).
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T* x,
                                                     T val,
                                                     int lanes = 1,
                                                     bool share_result = false)
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    // blockSize is intended to be a multiple of 32.
    int blockSize = blockDim.x * blockDim.y;

    if (blockSize >= 64) {
        x[tid] = val;
        __syncthreads();
    }

    // Halve the active range until 64 partial sums remain in x[0..63].
#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
        if (tid < i) x[tid] = x[tid] + x[tid + i];
        __syncthreads();
    }

    // Start from `val` so threads outside the first 32 never return an
    // indeterminate value.
    T final = val;

    if (tid < 32) {
        if (blockSize >= 64) final = x[tid] + x[tid + 32];
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result) {
        if (tid < lanes) x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
deepspeed/ops/csrc/includes/type_shim_hip.h
deleted
100644 → 0
View file @
1b2721ad
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
// Switches on TYPE (an at::ScalarType) and, for Float, Half and BFloat16,
// defines the alias `scalar_t_<LEVEL>` for the matching C++ type and then
// executes the statements passed as the variadic arguments. Any other type
// raises AT_ERROR with NAME in the message.
// Note: despite the macro name, BFloat16 is handled as well.
// The error path calls `toString` unqualified, so it must be resolvable at the
// expansion site.
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
// Same dispatch scheme as DISPATCH_FLOAT_AND_HALF with an additional Double
// case: for Double, Float, Half and BFloat16 the alias `scalar_t_<LEVEL>` is
// defined and the variadic statements are executed; any other type raises
// AT_ERROR with NAME in the message.
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
// Dispatch restricted to Double and Float: defines `scalar_t_<LEVEL>` as
// double or float and executes the variadic statements; any other type raises
// AT_ERROR with NAME in the message.
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
/// Sums `val` across the threads of a block, using `x` as scratch space.
///
/// @param x            Scratch array indexed by the flattened thread id; it is
///                     written whenever the block has at least 64 threads.
/// @param val          This thread's contribution.
/// @param lanes        The shuffle steps run for offsets 16, 8, ... down to
///                     `lanes`. Intended to be <= 32 and must be >= 1: with
///                     `lanes <= 0` the shuffle loop condition never turns false.
/// @param share_result When true, threads with `tid < lanes` store their
///                     result into `x[tid]` and the whole block synchronizes.
/// @return The reduced value for threads with `tid < 32`; every other thread
///         gets its own `val` back (previously an uninitialized value).
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T* x,
                                                     T val,
                                                     int lanes = 1,
                                                     bool share_result = false)
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    // blockSize is intended to be a multiple of 32.
    int blockSize = blockDim.x * blockDim.y;

    if (blockSize >= 64) {
        x[tid] = val;
        __syncthreads();
    }

    // Halve the active range until 64 partial sums remain in x[0..63].
#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
        if (tid < i) x[tid] = x[tid] + x[tid + i];
        __syncthreads();
    }

    // Start from `val` so threads outside the first 32 never return an
    // indeterminate value.
    T final = val;

    if (tid < 32) {
        if (blockSize >= 64) final = x[tid] + x[tid + 32];
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result) {
        if (tid < lanes) x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
deepspeed/ops/csrc/lamb/fused_lamb_cuda.cpp
deleted
100644 → 0
View file @
1b2721ad
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <torch/extension.h>
// CUDA forward declaration
// Declaration only; the extension binding below forwards its arguments to this
// function together with three scratch/result tensors it allocates itself
// (w_l2_i, u_l2_i and lamb_coeff_val).
void fused_lamb_cuda(at::Tensor& p,
                     at::Tensor& p_copy,
                     at::Tensor& m,
                     at::Tensor& v,
                     at::Tensor& g,
                     float lr,
                     float beta1,
                     float beta2,
                     float max_coeff,
                     float min_coeff,
                     float eps,
                     float grad_scale,
                     int step,
                     int mode,
                     int bias_correction,
                     float decay,
                     at::Tensor& w_l2_i,
                     at::Tensor& u_l2_i,
                     at::Tensor& lamb_coeff_val);
// Argument validation helpers; each raises through AT_ASSERTM with the
// stringified argument name in the message.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
// Wrapped in do/while(0) so the macro expands to a single statement: without
// it, `if (cond) CHECK_INPUT(t);` would guard only CHECK_CUDA and always run
// CHECK_CONTIGUOUS.
#define CHECK_INPUT(x)       \
    do {                     \
        CHECK_CUDA(x);       \
        CHECK_CONTIGUOUS(x); \
    } while (0)
// C++ interface
/// Validates the optimizer tensors, allocates the reduction scratch buffers
/// and launches the fused LAMB step via `fused_lamb_cuda`.
///
/// `p`, `m`, `v` and `g` must be contiguous CUDA tensors with the same number
/// of elements; `p_copy` must either match that size or be empty.
///
/// @return A one-element tensor that is passed to `fused_lamb_cuda` as its
///         `lamb_coeff_val` argument.
at::Tensor lamb(at::Tensor& p,
                at::Tensor& p_copy,
                at::Tensor& m,
                at::Tensor& v,
                at::Tensor& g,
                float lr,
                float beta1,
                float beta2,
                float max_coeff,
                float min_coeff,
                float eps,
                float grad_scale,
                int step,
                int mode,
                int bias_correction,
                float decay)
{
    CHECK_INPUT(p);
    // Braces are required: CHECK_INPUT may expand to more than one statement,
    // and every check must be skipped for an empty p_copy.
    if (p_copy.numel() > 0) { CHECK_INPUT(p_copy); }
    CHECK_INPUT(m);
    CHECK_INPUT(v);
    CHECK_INPUT(g);

    int64_t num_elem = p.numel();
    AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
    AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
    AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
    AT_ASSERTM(
        p_copy.numel() == num_elem || p_copy.numel() == 0,
        "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

    // Half parameters get Float scratch buffers; every other dtype is kept.
    const at::ScalarType param_dtype = p.scalar_type();
    const at::ScalarType scratch_dtype =
        (param_dtype == at::ScalarType::Half) ? at::ScalarType::Float : param_dtype;
    const auto scratch_options = p.options().dtype(scratch_dtype);

    // intermediate for weight L2 reduction
    // make sure that the threads per block is at least 512 during the kernel launch otherwise the
    // behaviour is unexpected
    at::Tensor w_l2_i = at::empty({512}, scratch_options);

    // intermediate for update L2 reduction
    // make sure that the threads per block is at least 512 during the kernel launch otherwise the
    // behaviour is unexpected
    at::Tensor u_l2_i = at::empty({512}, scratch_options);

    at::Tensor lamb_coeff_val = at::empty({1}, scratch_options);

    fused_lamb_cuda(p,
                    p_copy,
                    m,
                    v,
                    g,
                    lr,
                    beta1,
                    beta2,
                    max_coeff,
                    min_coeff,
                    eps,
                    grad_scale,
                    step,
                    mode,
                    bias_correction,
                    decay,
                    w_l2_i,
                    u_l2_i,
                    lamb_coeff_val);

    return lamb_coeff_val;
}
// Python binding: exposes the C++ `lamb` entry point under the name "lamb".
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("lamb", &lamb, "Adam optimized CUDA implementation with LAMB.");
}
deepspeed/ops/csrc/lamb/fused_lamb_cuda_kernel.cu
deleted
100644 → 0
View file @
1b2721ad
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
//#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <iostream>
//#include <helper_functions.h>
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#include <cuda_runtime_api.h>
#include <stdio.h>
namespace
cg
=
cooperative_groups
;
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {

// Primary template. It is not meant to be used: outside Windows its conversion
// operator calls a device function that is declared but never defined, so an
// un-specialized use fails to build.
template <typename T>
struct SharedMemory {
    // Ensure that we won't compile any un-specialized types
    __device__ inline operator T*()
    {
#ifndef _WIN32
        extern __device__ void error(void);
        error();
#endif
        return NULL;
    }
};

// float: converts to the unsized extern __shared__ array `s_float`.
template <>
struct SharedMemory<float> {
    __device__ inline operator float*()
    {
        extern __shared__ float s_float[];
        return s_float;
    }
};

// double: converts to the unsized extern __shared__ array `s_double`.
template <>
struct SharedMemory<double> {
    __device__ inline operator double*()
    {
        extern __shared__ double s_double[];
        return s_double;
    }
};

}  // namespace
#include "type_shim.h"
// Selects where eps enters the denominator of the update.
typedef enum {
    ADAM_MODE_0 = 0,  // eps under square root
    ADAM_MODE_1 = 1   // eps outside square root
} adamMode_t;
// s_a and s_b are in shared memory
// g_a and g_b are in global memory (one result slot per block)
// Reduces two shared-memory arrays of per-thread partial sums at once and
// stores the two block totals in g_a[blockIdx.x] / g_b[blockIdx.x].
// blockSize is a compile-time value that selects which halving steps are
// emitted.
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
    // Handle to thread block group
    cg::thread_block block = cg::this_thread_block();

    // Every thread starts from its own slot.
    unsigned int rank = block.thread_rank();
    T acc_a = s_a[rank];
    T acc_b = s_b[rank];

    cg::sync(block);

    // Tree reduction in shared memory: 512 -> 256 -> 128 -> 64 partial sums.
    if ((blockSize >= 512) && (rank < 256)) {
        acc_a = acc_a + s_a[rank + 256];
        s_a[rank] = acc_a;
        acc_b = acc_b + s_b[rank + 256];
        s_b[rank] = acc_b;
    }

    cg::sync(block);

    if ((blockSize >= 256) && (rank < 128)) {
        acc_a = acc_a + s_a[rank + 128];
        s_a[rank] = acc_a;
        acc_b = acc_b + s_b[rank + 128];
        s_b[rank] = acc_b;
    }

    cg::sync(block);

    if ((blockSize >= 128) && (rank < 64)) {
        acc_a = acc_a + s_a[rank + 64];
        s_a[rank] = acc_a;
        acc_b = acc_b + s_b[rank + 64];
        s_b[rank] = acc_b;
    }

    cg::sync(block);

#if (__CUDA_ARCH__ >= 300)
    if (rank < 32) {
        cg::coalesced_group active = cg::coalesced_threads();

        // Fetch final intermediate sum from 2nd warp
        if (blockSize >= 64) {
            acc_a = acc_a + s_a[rank + 32];
            acc_b = acc_b + s_b[rank + 32];
        }

        // Reduce final warp using shuffle
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            acc_a += active.shfl_down(acc_a, offset);
            acc_b += active.shfl_down(acc_b, offset);
        }
    }
#else
    // Fallback without shuffles: keep halving through shared memory.
    if ((blockSize >= 64) && (rank < 32)) {
        acc_a = acc_a + s_a[rank + 32];
        s_a[rank] = acc_a;
        acc_b = acc_b + s_b[rank + 32];
        s_b[rank] = acc_b;
    }

    cg::sync(block);

    if ((blockSize >= 32) && (rank < 16)) {
        acc_a = acc_a + s_a[rank + 16];
        s_a[rank] = acc_a;
        acc_b = acc_b + s_b[rank + 16];
        s_b[rank] = acc_b;
    }

    cg::sync(block);

    if ((blockSize >= 16) && (rank < 8)) {
        acc_a = acc_a + s_a[rank + 8];
        s_a[rank] = acc_a;
        acc_b = acc_b + s_b[rank + 8];
        s_b[rank] = acc_b;
    }

    cg::sync(block);

    if ((blockSize >= 8) && (rank < 4)) {
        acc_a = acc_a + s_a[rank + 4];
        s_a[rank] = acc_a;
        acc_b = acc_b + s_b[rank + 4];
        s_b[rank] = acc_b;
    }

    cg::sync(block);

    if ((blockSize >= 4) && (rank < 2)) {
        acc_a = acc_a + s_a[rank + 2];
        s_a[rank] = acc_a;
        acc_b = acc_b + s_b[rank + 2];
        s_b[rank] = acc_b;
    }

    cg::sync(block);

    if ((blockSize >= 2) && (rank < 1)) {
        acc_a = acc_a + s_a[rank + 1];
        s_a[rank] = acc_a;
        acc_b = acc_b + s_b[rank + 1];
        s_b[rank] = acc_b;
    }

    cg::sync(block);
#endif

    // write result for this block to global mem
    if (rank == 0) {
        g_a[blockIdx.x] = static_cast<T>(acc_a);
        g_b[blockIdx.x] = static_cast<T>(acc_b);
    }
}
// Stages each thread's two register values in dynamic shared memory (values
// of `a` first, values of `b` one block-size further on) and hands both arrays
// to reduce_block_in_shared_memory.
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
    cg::thread_block block = cg::this_thread_block();
    const int slot = block.thread_rank();

    T* first_half = SharedMemory<T>();
    T* second_half = SharedMemory<T>() + block.size();

    first_half[slot] = a;
    second_half[slot] = b;

    reduce_block_in_shared_memory<T, blockSize>(first_half, second_half, g_a, g_b);
}
// First LAMB pass: updates the moment buffers m and v in place and
// accumulates, per thread, the sum of squared parameters and the sum of
// squared updates. The per-block totals are written to w_l2_i / u_l2_i by
// reduce_two_vectors_in_register. p_copy and step_size are not used in this
// pass.
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1,
    const float b2,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay,
    T* __restrict__ w_l2_i,
    T* __restrict__ u_l2_i)
{
    // Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = cg::this_thread_block().thread_rank();
    const int first = (blockId * threadsPerBlock + threadIdInBlock);
    const int stride = gridDim.x * gridDim.y * threadsPerBlock;

    T sum_sq_w = 0;
    T sum_sq_u = 0;

    for (int idx = first; idx < tsize; idx += stride) {
        T scaled_grad = g[idx] / grad_scale;
        T param = p[idx];

        m[idx] = b1 * m[idx] + (1 - b1) * scaled_grad;
        v[idx] = b2 * v[idx] + (1 - b2) * scaled_grad * scaled_grad;

        // Mode 0 keeps eps under the square root, mode 1 adds it afterwards.
        const float denom =
            (mode == ADAM_MODE_0) ? sqrtf(v[idx] + eps) : sqrtf(v[idx]) + eps;

        T update = (m[idx] / denom) + (decay * param);

        sum_sq_u += update * update;
        sum_sq_w += param * param;
    }

    reduce_two_vectors_in_register<T, blockSize>(sum_sq_w, sum_sq_u, w_l2_i, u_l2_i);
}
// Second LAMB pass: loads the first `tsize` entries of g_a / g_b into shared
// memory (slots at or beyond `tsize` are zero-filled) and reduces them with
// reduce_block_in_shared_memory, which writes the totals back to g_a / g_b.
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
    T* s_a = SharedMemory<T>();
    T* s_b = SharedMemory<T>() + cg::this_thread_block().size();

    const int slot = cg::this_thread_block().thread_rank();

    if (slot >= tsize) {
        // Slots with no matching partial sum contribute nothing.
        s_a[slot] = 0.0;
        s_b[slot] = 0.0;
    } else {
        s_a[slot] = g_a[slot];
        s_b[slot] = g_b[slot];
    }

    reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
// Third LAMB pass: derives the coefficient from the two reduced values in
// w_l2_i[0] and u_l2_i[0], publishes it through lamb_coeff_val[0], and applies
// the scaled update to every parameter. When p_copy is non-null each updated
// parameter is also written there, converted to GRAD_T.
//
// The coefficient is sqrt(w) / sqrt(u) clamped to [min_coeff, max_coeff]; it
// stays 1.0 when either square root is zero.
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1,
    const float b2,
    const float max_coeff,
    const float min_coeff,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay,
    T* __restrict__ w_l2_i,
    T* __restrict__ u_l2_i,
    T* __restrict__ lamb_coeff_val)
{
    // Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = cg::this_thread_block().thread_rank();
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

    T reg_w = sqrtf(w_l2_i[0]);
    T reg_u = sqrtf(u_l2_i[0]);

    float lamb_coeff = 1.0;

    if (reg_w != 0 && reg_u != 0) {
        lamb_coeff = reg_w / reg_u;
        if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
        if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
    }

    // A single thread of the grid reports the coefficient.
    if (blockId == 0 && threadIdInBlock == 0) {
        lamb_coeff_val[0] = lamb_coeff;
        // printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
    }

    for (int j = i; j < tsize; j += totThreads) {
        // Read the parameter at its native precision. The previous `(float)`
        // cast rounded double parameters to float on every step, unlike
        // part1, which reads p[j] without a cast.
        T pj = p[j];
        T mj = m[j];
        T vj = v[j];

        float denom;
        if (mode == ADAM_MODE_0)
            denom = sqrtf(vj + eps);
        else  // Mode 1
            denom = sqrtf(vj) + eps;

        T update = (mj / denom) + (decay * pj);

        pj = pj - (step_size * lamb_coeff * update);
        p[j] = pj;
        if (p_copy != nullptr) p_copy[j] = static_cast<GRAD_T>(pj);
    }
}
// Host-side launcher for the three LAMB kernels on the current CUDA stream.
//
// Launch shape: 512 threads per block and at most 512 blocks. Dynamic shared
// memory holds two arrays of 512 elements (double-sized when p is Double,
// float-sized otherwise).
//
// When the gradient is Half the parameter must be Float, and the kernels are
// instantiated with the accumulate type of the gradient as T. Otherwise the
// gradient's own floating type is used for everything and p_copy is not
// written.
//
// w_l2_i / u_l2_i are scratch buffers for the per-block partial sums and
// lamb_coeff receives the coefficient computed by the third kernel.
void fused_lamb_cuda(at::Tensor& p,
                     at::Tensor& p_copy,
                     at::Tensor& m,
                     at::Tensor& v,
                     at::Tensor& g,
                     float lr,
                     float beta1,
                     float beta2,
                     float max_coeff,
                     float min_coeff,
                     float eps,
                     float grad_scale,
                     int step,
                     int mode,
                     int bias_correction,
                     float decay,
                     at::Tensor& w_l2_i,
                     at::Tensor& u_l2_i,
                     at::Tensor& lamb_coeff)
{
    // using namespace at;

    // Check the size limit before narrowing numel() to int: the narrowed value
    // is used in arithmetic below and would be meaningless for larger tensors.
    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
               "parameter tensor is too large to be indexed with int32");

    // Get tensor size
    const int tsize = static_cast<int>(p.numel());

    // Determine #threads and #blocks
    const int threadsPerBlock = 512;
    // Round up in 64-bit so the addition cannot overflow for sizes near INT_MAX.
    const int64_t blocks_needed = (p.numel() + threadsPerBlock - 1) / threadsPerBlock;
    const int num_blocks = blocks_needed > 512 ? 512 : static_cast<int>(blocks_needed);

    int smemsize = 0;
    if (p.scalar_type() == at::ScalarType::Double)
        smemsize = 2 * threadsPerBlock * sizeof(double);
    else
        smemsize = 2 * threadsPerBlock * sizeof(float);

    const dim3 blocks(num_blocks);

    // Constants
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
    } else {
        step_size = lr;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (g.scalar_type() == at::ScalarType::Half) {
        // all other values should be fp32 for half gradients
        AT_ASSERTM(p.scalar_type() == at::ScalarType::Float,
                   "expected parameter to be of float type");
        // dispatch is done on the gradient type
        using namespace at;  // prevents "toString is undefined" errors
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            g.scalar_type(), "lamb_cuda_kernel", ([&] {
                using accscalar_t = at::acc_type<scalar_t, true>;

                lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<accscalar_t>(),
                        u_l2_i.data<accscalar_t>());

                // Single block: folds the per-block partial sums into element 0.
                lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>
                    <<<1, threadsPerBlock, smemsize, stream>>>(
                        num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());

                lamb_cuda_kernel_part3<accscalar_t, scalar_t>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<accscalar_t>(),
                        u_l2_i.data<accscalar_t>(),
                        lamb_coeff.data<accscalar_t>());
            }));
    } else {
        using namespace at;
        AT_DISPATCH_FLOATING_TYPES(
            g.scalar_type(), "lamb_cuda_kernel", ([&] {
                lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<scalar_t>(),
                        NULL,  // don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<scalar_t>(),
                        u_l2_i.data<scalar_t>());

                // Single block: folds the per-block partial sums into element 0.
                lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>
                    <<<1, threadsPerBlock, smemsize, stream>>>(
                        num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());

                lamb_cuda_kernel_part3<scalar_t, scalar_t>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<scalar_t>(),
                        NULL,  // don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<scalar_t>(),
                        u_l2_i.data<scalar_t>(),
                        lamb_coeff.data<scalar_t>());
            }));
    }
    C10_CUDA_CHECK(cudaGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);
deepspeed/ops/csrc/lamb/fused_lamb_hip_kernel.hip
deleted
100644 → 0
View file @
1b2721ad
// !!! This is a file automatically generated by hipify!!!
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/hip/HIPContext.h"
#include "ATen/hip/detail/IndexUtils.cuh"
//#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <iostream>
//#include <helper_functions.h>
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#include <hip/hip_runtime_api.h>
#include <stdio.h>
namespace cg = cooperative_groups;
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {

// Typed accessor for the kernel's dynamically sized shared-memory buffer.
// Converting a SharedMemory<T> object to T* yields the start of that buffer.
// Only the float and double specializations are usable: the primary template
// calls a device function that is declared but never defined, so any other
// instantiation fails to build.
template <typename T>
struct SharedMemory {
    __device__ inline operator T*()
    {
#ifndef _WIN32
        // Intentionally undefined symbol; see the note above.
        extern __device__ void error(void);
        error();
#endif
        return nullptr;
    }
};

// float view of the dynamic shared-memory buffer.
template <>
struct SharedMemory<float> {
    __device__ inline operator float*()
    {
        HIP_DYNAMIC_SHARED(float, dyn_smem_f32)
        return dyn_smem_f32;
    }
};

// double view of the dynamic shared-memory buffer.
template <>
struct SharedMemory<double> {
    __device__ inline operator double*()
    {
        HIP_DYNAMIC_SHARED(double, dyn_smem_f64)
        return dyn_smem_f64;
    }
};

}  // namespace
#include "type_shim_hip.h"
// Selects where eps enters the denominator of the Adam-style update:
// mode 0 uses sqrt(v + eps), mode 1 uses sqrt(v) + eps.
enum adamMode_t {
    ADAM_MODE_0 = 0,  // eps under square root
    ADAM_MODE_1 = 1   // eps outside square root
};
// Reduces two per-thread arrays to a single sum each, one pair per block.
// s_a and s_b are in shared memory (one entry per thread of the block)
// g_a and g_b are in global memory (one output entry per block, indexed by blockIdx.x)
// blockSize is the compile-time block size; it selects which halving steps are emitted.
// Every thread of the block must call this function because it contains block-wide barriers.
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
// Handle to thread block group
cg::thread_block cta = cg::this_thread_block();
// perform block reduction in shared memory,
unsigned int tid = cta.thread_rank();
// Each thread keeps its running partial sums in registers as well as in shared memory.
T a_sum = s_a[tid];
T b_sum = s_b[tid];
cg::sync(cta);
// do reduction in shared mem
// Each step folds the upper half of the active range onto the lower half;
// the barrier after each step makes the written sums visible to the next step.
if ((blockSize >= 512) && (tid < 256)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 256];
s_b[tid] = b_sum = b_sum + s_b[tid + 256];
}
cg::sync(cta);
if ((blockSize >= 256) && (tid < 128)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 128];
s_b[tid] = b_sum = b_sum + s_b[tid + 128];
}
cg::sync(cta);
if ((blockSize >= 128) && (tid < 64)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 64];
s_b[tid] = b_sum = b_sum + s_b[tid + 64];
}
cg::sync(cta);
// Two alternatives for the last 64 entries: a register shuffle path when
// __CUDA_ARCH__ >= 300 is defined, otherwise the shared-memory halving continues
// down to a single element.
#if (__CUDA_ARCH__ >= 300)
if (tid < 32) {
cg::coalesced_group active = cg::coalesced_threads();
// Fetch final intermediate sum from 2nd warp
if (blockSize >= 64) {
a_sum = a_sum + s_a[tid + 32];
b_sum = b_sum + s_b[tid + 32];
}
// Reduce final warp using shuffle
// (on this path the totals end up only in thread 0's registers, not in s_a/s_b)
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
a_sum += active.shfl_down(a_sum, offset);
b_sum += active.shfl_down(b_sum, offset);
}
}
#else
if ((blockSize >= 64) && (tid < 32)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 32];
s_b[tid] = b_sum = b_sum + s_b[tid + 32];
}
cg::sync(cta);
if ((blockSize >= 32) && (tid < 16)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 16];
s_b[tid] = b_sum = b_sum + s_b[tid + 16];
}
cg::sync(cta);
if ((blockSize >= 16) && (tid < 8)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 8];
s_b[tid] = b_sum = b_sum + s_b[tid + 8];
}
cg::sync(cta);
if ((blockSize >= 8) && (tid < 4)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 4];
s_b[tid] = b_sum = b_sum + s_b[tid + 4];
}
cg::sync(cta);
if ((blockSize >= 4) && (tid < 2)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 2];
s_b[tid] = b_sum = b_sum + s_b[tid + 2];
}
cg::sync(cta);
if ((blockSize >= 2) && (tid < 1)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 1];
s_b[tid] = b_sum = b_sum + s_b[tid + 1];
}
cg::sync(cta);
#endif
// write result for this block to global mem
// (only thread 0 holds the complete sums at this point)
if (tid == 0) {
g_a[blockIdx.x] = (T)a_sum;
g_b[blockIdx.x] = (T)b_sum;
}
}
// Block-reduces two per-thread register values at once.
// Each thread stores `a` and `b` into the dynamic shared-memory buffer (first
// block-size entries for `a`, the following block-size entries for `b`), then
// the block-wide reduction writes one sum per block to g_a[blockIdx.x] and
// g_b[blockIdx.x]. The launch must therefore supply at least
// 2 * block size * sizeof(T) bytes of dynamic shared memory.
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
    cg::thread_block block = cg::this_thread_block();
    const int rank_in_block = block.thread_rank();

    T* const smem_a = SharedMemory<T>();
    T* const smem_b = smem_a + block.size();

    smem_a[rank_in_block] = a;
    smem_b[rank_in_block] = b;

    reduce_block_in_shared_memory<T, blockSize>(smem_a, smem_b, g_a, g_b);
}
// LAMB step, phase 1: updates the moment buffers and accumulates squared norms.
//
// For every element j < tsize this kernel
//   * rescales the gradient by 1 / grad_scale,
//   * updates m[j] and v[j] as exponential moving averages of the gradient and
//     of its square (coefficients b1 and b2),
//   * forms update = m / denom + decay * p, where denom depends on `mode`,
//   * accumulates p^2 and update^2 into per-thread partial sums.
// The partial sums are reduced per block; block k writes its sums to
// w_l2_i[k] (parameters) and u_l2_i[k] (updates).
//
// p is only read here; p_copy and step_size are not used by this phase.
// blockSize must equal the launch block size.
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1,
    const float b2,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay,
    T* __restrict__ w_l2_i,
    T* __restrict__ u_l2_i)
{
    // Assuming 2D grids and 2D blocks
    const int block_threads = blockDim.x * blockDim.y;
    const int linear_block = gridDim.x * blockIdx.y + blockIdx.x;
    const int rank_in_block = cg::this_thread_block().thread_rank();
    const int first = (linear_block * block_threads + rank_in_block);
    const int stride = gridDim.x * gridDim.y * block_threads;

    T w_sq_sum = 0;  // running sum of p[j]^2 for this thread
    T u_sq_sum = 0;  // running sum of update^2 for this thread

    for (int j = first; j < tsize; j += stride) {
        const T scaled_grad = g[j] / grad_scale;
        const T pj = p[j];

        m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
        v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;

        // Mode 0 puts eps under the square root, mode 1 adds it afterwards.
        const float denom = (mode == ADAM_MODE_0) ? sqrtf(v[j] + eps) : sqrtf(v[j]) + eps;

        const T update = (m[j] / denom) + (decay * p[j]);
        u_sq_sum += update * update;
        w_sq_sum += pj * pj;
    }

    reduce_two_vectors_in_register<T, blockSize>(w_sq_sum, u_sq_sum, w_l2_i, u_l2_i);
}
// LAMB step, phase 2: folds the per-block partial sums into a single value.
//
// Launched with one block. Thread t loads g_a[t] / g_b[t] when t < tsize
// (tsize is the number of valid partial sums) and contributes zero otherwise;
// the block reduction then writes the totals to g_a[0] and g_b[0].
// The launch must supply 2 * block size * sizeof(T) bytes of dynamic shared
// memory, and blockSize must equal the launch block size.
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
    T* s_a = SharedMemory<T>();
    T* s_b = SharedMemory<T>() + cg::this_thread_block().size();

    const int threadIdInBlock = cg::this_thread_block().thread_rank();

    // Only touch global memory for indices that actually hold a partial sum;
    // reading first and zeroing afterwards would access entries past tsize.
    const bool has_partial = threadIdInBlock < tsize;
    s_a[threadIdInBlock] = has_partial ? g_a[threadIdInBlock] : T(0);
    s_b[threadIdInBlock] = has_partial ? g_b[threadIdInBlock] : T(0);

    reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
// LAMB step, phase 3: applies the scaled update to the parameters.
//
// w_l2_i[0] and u_l2_i[0] hold the summed squares of the parameters and of the
// updates. Every thread derives the same coefficient from them:
//   coeff = sqrt(w) / sqrt(u), limited to max_coeff and then to min_coeff,
// or 1 when either norm is zero. Thread 0 of block 0 stores the coefficient in
// lamb_coeff_val[0].
//
// Each element is then updated as
//   p -= step_size * coeff * (m / denom + decay * p)
// and, when p_copy is non-NULL, the new value is also written there as GRAD_T.
// g, b1, b2 and grad_scale are not used by this phase.
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1,
    const float b2,
    const float max_coeff,
    const float min_coeff,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay,
    T* __restrict__ w_l2_i,
    T* __restrict__ u_l2_i,
    T* __restrict__ lamb_coeff_val)
{
    // Assuming 2D grids and 2D blocks
    const int block_threads = blockDim.x * blockDim.y;
    const int linear_block = gridDim.x * blockIdx.y + blockIdx.x;
    const int rank_in_block = cg::this_thread_block().thread_rank();
    const int first = (linear_block * block_threads + rank_in_block);
    const int stride = gridDim.x * gridDim.y * block_threads;

    const T w_norm = sqrtf(w_l2_i[0]);
    const T u_norm = sqrtf(u_l2_i[0]);

    float coeff = 1.0;
    if (w_norm != 0 && u_norm != 0) {
        coeff = w_norm / u_norm;
        if (coeff > max_coeff) { coeff = max_coeff; }
        if (coeff < min_coeff) { coeff = min_coeff; }
    }

    // Publish the coefficient once for the caller.
    if (linear_block == 0 && rank_in_block == 0) { lamb_coeff_val[0] = coeff; }

    for (int j = first; j < tsize; j += stride) {
        T pj = (float)p[j];
        const T mj = m[j];
        const T vj = v[j];

        // Mode 0 puts eps under the square root, mode 1 adds it afterwards.
        const float denom = (mode == ADAM_MODE_0) ? sqrtf(vj + eps) : sqrtf(vj) + eps;

        const T update = (mj / denom) + (decay * pj);
        pj = pj - (step_size * coeff * update);

        p[j] = pj;
        if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
    }
}
// Host entry point for one fused LAMB optimizer step on the current stream.
//
// Launches three kernels in order:
//   1. part1 - updates m and v and writes per-block sums of p^2 and update^2
//              into w_l2_i / u_l2_i,
//   2. part2 - a single block that folds those per-block sums into element 0,
//   3. part3 - computes the clamped coefficient, stores it in lamb_coeff and
//              applies the update to p (and to p_copy for half gradients).
//
// Parameters
//   p, m, v     parameter and moment tensors, updated in place
//   p_copy      optional reduced-precision copy of p; skipped when it has no
//               elements, and never written on the fp32/fp64 path
//   g           gradient tensor; its dtype selects the dispatch path
//   lr, beta1, beta2, eps, decay   optimizer hyper-parameters
//   max_coeff, min_coeff           bounds applied to the coefficient
//   grad_scale  gradients are divided by this value before use
//   step        step count, used only for bias correction
//   mode        cast to adamMode_t (0: eps inside sqrt, 1: eps outside)
//   bias_correction  when 1, lr is scaled by sqrt(1 - beta2^step) / (1 - beta1^step)
//   w_l2_i, u_l2_i   scratch tensors receiving the squared-norm partial sums
//   lamb_coeff       receives the coefficient in element 0
//
// When g is Half, p must be Float (asserted). The last launch error is checked
// before returning.
void fused_lamb_cuda(at::Tensor& p,
                     at::Tensor& p_copy,
                     at::Tensor& m,
                     at::Tensor& v,
                     at::Tensor& g,
                     float lr,
                     float beta1,
                     float beta2,
                     float max_coeff,
                     float min_coeff,
                     float eps,
                     float grad_scale,
                     int step,
                     int mode,
                     int bias_correction,
                     float decay,
                     at::Tensor& w_l2_i,
                     at::Tensor& u_l2_i,
                     at::Tensor& lamb_coeff)
{
    // Get tensor size
    int tsize = p.numel();
    // Determine #threads and #blocks
    const int threadsPerBlock = 512;
    int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
    // part2 folds the per-block partial sums with a single block of
    // threadsPerBlock threads, one partial per thread, hence the cap.
    if (num_blocks > 512) num_blocks = 512;

    // Dynamic shared memory: two arrays (w and u partial sums) of one
    // accumulator per thread.
    int smemsize = 0;
    if (p.scalar_type() == at::ScalarType::Double)
        smemsize = 2 * threadsPerBlock * sizeof(double);
    else
        smemsize = 2 * threadsPerBlock * sizeof(float);

    const dim3 blocks(num_blocks);

    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
               "parameter tensor is too large to be indexed with int32");
    // Constants
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - ::pow(beta1, step);
        const float bias_correction2 = 1 - ::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
    } else {
        step_size = lr;
    }
    hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();

    if (g.scalar_type() == at::ScalarType::Half) {
        // all other values should be fp32 for half gradients
        AT_ASSERTM(p.scalar_type() == at::ScalarType::Float,
                   "expected parameter to be of float type");
        // dispatch is done on the gradient type
        using namespace at;  // prevents "toString is undefined" errors
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            g.scalar_type(), "lamb_cuda_kernel", ([&] {
                using accscalar_t = at::acc_type<scalar_t, true>;

                hipLaunchKernelGGL(( lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>)
                    , dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<accscalar_t>(),
                        u_l2_i.data<accscalar_t>());

                hipLaunchKernelGGL(( lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>)
                    , dim3(1), dim3(threadsPerBlock), smemsize, stream,
                        num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());

                hipLaunchKernelGGL(( lamb_cuda_kernel_part3<accscalar_t, scalar_t>)
                    , dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<accscalar_t>(),
                        u_l2_i.data<accscalar_t>(),
                        lamb_coeff.data<accscalar_t>());
            }));
    } else {
        using namespace at;
        AT_DISPATCH_FLOATING_TYPES(
            g.scalar_type(), "lamb_cuda_kernel", ([&] {
                hipLaunchKernelGGL(( lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>)
                    , dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
                        p.data<scalar_t>(),
                        NULL,  // don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<scalar_t>(),
                        u_l2_i.data<scalar_t>());

                hipLaunchKernelGGL(( lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>)
                    , dim3(1), dim3(threadsPerBlock), smemsize, stream,
                        num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());

                hipLaunchKernelGGL(( lamb_cuda_kernel_part3<scalar_t, scalar_t>)
                    , dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
                        p.data<scalar_t>(),
                        NULL,  // don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<scalar_t>(),
                        u_l2_i.data<scalar_t>(),
                        lamb_coeff.data<scalar_t>());
            }));
    }
    C10_HIP_CHECK(hipGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);
deepspeed/ops/csrc/quantization/pt_binding.cpp
deleted
100644 → 0
View file @
1b2721ad
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <vector>
#include "custom_cuda_layers.h"
// Symmetric grouped quantization binding.
//
// Hands the tensor's storage (viewed as T*) to launch_quantize_kernel on the
// current CUDA stream and returns the same tensor object.
//
// vals    tensor to quantize; its element count is split into `groups` groups
// groups  number of groups; must be positive
// bits    quantization bit width forwarded to the kernel
//
// The launch is skipped, and `vals` returned unchanged, when a group would
// need more than MAX_REG chunks of 4096 elements.
template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
{
    // groups is used as a divisor and as the launch grid size.
    TORCH_CHECK(groups > 0, "ds_quantize: groups must be positive, got ", groups);

    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    // ceil(elements-per-group / 4096) must fit the kernel's per-thread budget.
    if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
        launch_quantize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
// Symmetric grouped quantization with stochastic rounding.
//
// Hands the tensor's storage (viewed as T*) to launch_sr_quantize_kernel on
// the current CUDA stream and returns the same tensor object.
//
// vals    tensor to quantize; its element count is split into `groups` groups
// groups  number of groups; must be positive
// bits    quantization bit width forwarded to the kernel
//
// The launch is skipped, and `vals` returned unchanged, when
// (elements-per-group / 4 / 1024) exceeds 256.
// NOTE(review): confirm that this 256 bound does not exceed the number of
// values the kernel can stage per thread.
template <typename T>
at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
{
    // groups is used as a divisor and as the launch grid size.
    TORCH_CHECK(groups > 0, "ds_sr_quantize: groups must be positive, got ", groups);

    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_quantize_kernel(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
// Asymmetric grouped quantization binding.
//
// Hands the tensor's storage (viewed as T*) to launch_quantize_kernel_asym on
// the current CUDA stream and returns the same tensor object.
//
// vals    tensor to quantize; its element count is split into `groups` groups
// groups  number of groups; must be positive
// bits    quantization bit width forwarded to the kernel
//
// The launch is skipped, and `vals` returned unchanged, when a group would
// need more than MAX_REG chunks of 4096 elements.
template <typename T>
at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    // groups is used as a divisor and as the launch grid size.
    TORCH_CHECK(groups > 0, "ds_quantize_asym: groups must be positive, got ", groups);

    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    // ceil(elements-per-group / 4096) must fit the kernel's per-thread budget.
    if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
        launch_quantize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
// Asymmetric grouped quantization with stochastic rounding.
//
// Hands the tensor's storage (viewed as T*) to launch_sr_quantize_kernel_asym
// on the current CUDA stream and returns the same tensor object.
//
// vals    tensor to quantize; its element count is split into `groups` groups
// groups  number of groups; must be positive
// bits    quantization bit width forwarded to the kernel
//
// The launch is skipped, and `vals` returned unchanged, when
// (elements-per-group / 4 / 1024) exceeds 256.
// NOTE(review): confirm that this 256 bound does not exceed the number of
// values the kernel can stage per thread.
template <typename T>
at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    // groups is used as a divisor and as the launch grid size.
    TORCH_CHECK(groups > 0, "ds_sr_quantize_asym: groups must be positive, got ", groups);

    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_quantize_kernel_asym(
            (T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
    }
    return vals;
}
// Python bindings for the quantizer extension. Each entry point is exported
// once per element type (float and __half).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // Symmetric quantization.
    m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");

    // Symmetric quantization with stochastic rounding.
    m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");

    // Asymmetric quantization.
    m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def(
        "ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");

    // Asymmetric quantization with stochastic rounding.
    m.def("ds_sr_quantize_asym_fp32",
          &ds_sr_quantize_asym<float>,
          "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_asym_fp16",
          &ds_sr_quantize_asym<__half>,
          "DeepSpeed Quantize with fp16 (CUDA)");
}
deepspeed/ops/csrc/quantization/pt_binding_hip.cpp
deleted
100644 → 0
View file @
1b2721ad
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include <vector>
#include "custom_hip_layers.h"
// Symmetric grouped quantization binding (HIP build).
//
// Hands the tensor's storage (viewed as T*) to launch_quantize_kernel on the
// current HIP stream and returns the same tensor object.
//
// vals    tensor to quantize; its element count is split into `groups` groups
// groups  number of groups; must be positive
// bits    quantization bit width forwarded to the kernel
//
// The launch is skipped, and `vals` returned unchanged, when a group would
// need more than MAX_REG chunks of 4096 elements.
template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
{
    // groups is used as a divisor and as the launch grid size.
    TORCH_CHECK(groups > 0, "ds_quantize: groups must be positive, got ", groups);

    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    // ceil(elements-per-group / 4096) must fit the kernel's per-thread budget.
    if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
        launch_quantize_kernel((T*)vals.data_ptr(),
                               size,
                               groups,
                               bits,
                               at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    }
    return vals;
}
// Symmetric grouped quantization with stochastic rounding (HIP build).
//
// Hands the tensor's storage (viewed as T*) to launch_sr_quantize_kernel on
// the current HIP stream and returns the same tensor object.
//
// vals    tensor to quantize; its element count is split into `groups` groups
// groups  number of groups; must be positive
// bits    quantization bit width forwarded to the kernel
//
// The launch is skipped, and `vals` returned unchanged, when
// (elements-per-group / 4 / 1024) exceeds 256.
// NOTE(review): confirm that this 256 bound does not exceed the number of
// values the kernel can stage per thread.
template <typename T>
at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
{
    // groups is used as a divisor and as the launch grid size.
    TORCH_CHECK(groups > 0, "ds_sr_quantize: groups must be positive, got ", groups);

    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_quantize_kernel((T*)vals.data_ptr(),
                                  size,
                                  groups,
                                  bits,
                                  at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    }
    return vals;
}
// Asymmetric grouped quantization binding (HIP build).
//
// Hands the tensor's storage (viewed as T*) to launch_quantize_kernel_asym on
// the current HIP stream and returns the same tensor object.
//
// vals    tensor to quantize; its element count is split into `groups` groups
// groups  number of groups; must be positive
// bits    quantization bit width forwarded to the kernel
//
// The launch is skipped, and `vals` returned unchanged, when a group would
// need more than MAX_REG chunks of 4096 elements.
template <typename T>
at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    // groups is used as a divisor and as the launch grid size.
    TORCH_CHECK(groups > 0, "ds_quantize_asym: groups must be positive, got ", groups);

    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    // ceil(elements-per-group / 4096) must fit the kernel's per-thread budget.
    if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
        launch_quantize_kernel_asym((T*)vals.data_ptr(),
                                    size,
                                    groups,
                                    bits,
                                    at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    }
    return vals;
}
// Asymmetric grouped quantization with stochastic rounding (HIP build).
//
// Hands the tensor's storage (viewed as T*) to launch_sr_quantize_kernel_asym
// on the current HIP stream and returns the same tensor object.
//
// vals    tensor to quantize; its element count is split into `groups` groups
// groups  number of groups; must be positive
// bits    quantization bit width forwarded to the kernel
//
// The launch is skipped, and `vals` returned unchanged, when
// (elements-per-group / 4 / 1024) exceeds 256.
// NOTE(review): confirm that this 256 bound does not exceed the number of
// values the kernel can stage per thread.
template <typename T>
at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
{
    // groups is used as a divisor and as the launch grid size.
    TORCH_CHECK(groups > 0, "ds_sr_quantize_asym: groups must be positive, got ", groups);

    auto t_size = vals.sizes();
    int size = 1;
    for (auto dim : t_size) size *= dim;

    if (((size / groups) / 4 / 1024) <= 256) {
        launch_sr_quantize_kernel_asym((T*)vals.data_ptr(),
                                       size,
                                       groups,
                                       bits,
                                       at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    }
    return vals;
}
// Python bindings for the quantizer extension (HIP build). Each entry point is
// exported once per element type (float and __half).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // Symmetric quantization.
    m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");

    // Symmetric quantization with stochastic rounding.
    m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");

    // Asymmetric quantization.
    m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
    m.def(
        "ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");

    // Asymmetric quantization with stochastic rounding.
    m.def("ds_sr_quantize_asym_fp32",
          &ds_sr_quantize_asym<float>,
          "DeepSpeed Quantize with fp32 (CUDA)");
    m.def("ds_sr_quantize_asym_fp16",
          &ds_sr_quantize_asym<__half>,
          "DeepSpeed Quantize with fp16 (CUDA)");
}
deepspeed/ops/csrc/quantization/quantizer.hip
deleted
100644 → 0
View file @
1b2721ad
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <math.h>
#include "custom_hip_layers.h"
namespace cg = cooperative_groups;
// Symmetric quantize-and-restore of half-precision data, one group per block.
//
// vals        data, rewritten in place
// group_size  length of one group counted in float2 vectors (4 halves each)
// num_bits    quantization bit width
//
// Each block finds the largest absolute value of its group, derives the scale
// 2^num_bits / (2 * max + 1e-5), and replaces every value x with
// round(x * scale) / scale. Each thread stages at most MAX_REG vectors;
// vectors beyond that limit are neither scanned nor rewritten.
// The body is compiled only when __CUDA_ARCH__ >= 700 or on the HIP platform.
__global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
// gid: index of this thread's 32-wide tile in the block; lane: position inside the tile.
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
// View the data as float2 so that one load/store moves four halves.
float2* vals_cast = reinterpret_cast<float2*>(vals);
float2 data[MAX_REG];
int group_id = blockIdx.x;
{
int group_index = id;
int reg_count = 0;
int offset = group_id * group_size;
float max = -10000.0;
// Pass 1: stage this thread's vectors and track the largest absolute value.
while (group_index < group_size && reg_count < MAX_REG) {
data[reg_count] = vals_cast[offset + group_index];
__half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
if (abs((float)data_h[0]) > max) max = abs((float)data_h[0]);
if (abs((float)data_h[1]) > max) max = abs((float)data_h[1]);
if (abs((float)data_h[2]) > max) max = abs((float)data_h[2]);
if (abs((float)data_h[3]) > max) max = abs((float)data_h[3]);
group_index += blockDim.x;
reg_count++;
}
// Maximum across the 32 lanes of the tile.
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
// Lane 0 of every tile publishes its maximum; after the barrier each tile
// reduces the published values again to obtain the block-wide maximum.
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
// Broadcast lane 0's result to the whole tile.
max = g.shfl(max, 0);
float q_scale = (1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
// Pass 2: round each staged value onto the quantization grid and write it back.
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
__half2* data_h = reinterpret_cast<__half2*>(&data[i]);
float2 q_data[2];
q_data[0] = __half22float2(data_h[0]);
q_data[1] = __half22float2(data_h[1]);
float2 q_data_int[2];
q_data_int[0].x = roundf(q_data[0].x * q_scale);
q_data_int[0].y = roundf(q_data[0].y * q_scale);
q_data_int[1].x = roundf(q_data[1].x * q_scale);
q_data_int[1].y = roundf(q_data[1].y * q_scale);
q_data_int[0].x *= q_scale_inv;
q_data_int[0].y *= q_scale_inv;
q_data_int[1].x *= q_scale_inv;
q_data_int[1].y *= q_scale_inv;
data_h[0] = __float22half2_rn(q_data_int[0]);
data_h[1] = __float22half2_rn(q_data_int[1]);
vals_cast[offset + group_index] = data[i];
}
}
}
#endif
}
// Symmetric quantize-and-restore of single-precision data, one group per block.
//
// vals        data, rewritten in place
// group_size  length of one group counted in float4 vectors (4 floats each)
// num_bits    quantization bit width
//
// Each block finds the largest absolute value of its group, derives the scale
// 2^num_bits / (2 * max + 1e-5), and replaces every value x with
// round(x * scale) / scale. Each thread stages at most MAX_REG vectors;
// vectors beyond that limit are neither scanned nor rewritten.
__global__ void quantize_kernel(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
// gid: index of this thread's 32-wide tile in the block; lane: position inside the tile.
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
// View the data as float4 so that one load/store moves four floats.
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[MAX_REG];
int bid = blockIdx.x;
int group_index = bid * group_size + id;
int reg_count = 0;
float max = -10000.0;
// Pass 1: stage this thread's vectors and track the largest absolute value.
// `id` walks the position inside the group, `group_index` the absolute position.
while (id < group_size && reg_count < MAX_REG) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (abs(data_reg.x) > max) max = abs(data_reg.x);
if (abs(data_reg.y) > max) max = abs(data_reg.y);
if (abs(data_reg.z) > max) max = abs(data_reg.z);
if (abs(data_reg.w) > max) max = abs(data_reg.w);
group_index += blockDim.x;
id += blockDim.x;
reg_count++;
}
// Restore the thread's own index for the write-back pass.
id = threadIdx.x;
// Maximum across the 32 lanes of the tile.
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
// Lane 0 of every tile publishes its maximum; after the barrier each tile
// reduces the published values again to obtain the block-wide maximum.
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
b.sync();
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
// Broadcast lane 0's result to the whole tile.
max = g.shfl(max, 0);
float q_scale = (1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
// Pass 2: round each staged value onto the quantization grid and write it back.
// Here group_index is the position inside the group; the block offset is added at the store.
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
float4 q_data;
q_data = data[i];
float4 q_data_int;
q_data_int.x = roundf(q_data.x * q_scale);
q_data_int.y = roundf(q_data.y * q_scale);
q_data_int.w = roundf(q_data.w * q_scale);
q_data_int.z = roundf(q_data.z * q_scale);
q_data.x = q_data_int.x * q_scale_inv;
q_data.y = q_data_int.y * q_scale_inv;
q_data.w = q_data_int.w * q_scale_inv;
q_data.z = q_data_int.z * q_scale_inv;
vals_cast[group_index + bid * group_size] = q_data;
}
}
}
template <typename T>
void launch_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
hipLaunchKernelGGL(( quantize_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, num_bits);
}
template void launch_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
// Symmetric quantize-and-restore of half-precision data with stochastic
// rounding, one group ("token") per block.
//
// vals        data, rewritten in place
// token_size  length of one group counted in float2 vectors (4 halves each)
// token_num   number of groups
// num_bits    quantization bit width
// seed        (seed, offset) pair used to initialise the per-thread Philox state
//
// Each block finds the largest absolute value of its group and derives the
// scale 2^num_bits / (2 * max + 1e-5). Every value is scaled and truncated
// toward zero; it is then moved one step away from zero when a uniform random
// number falls below the truncation error and the truncated value lies
// strictly between the lowest and highest representable level. The result is
// divided by the scale and written back.
//
// Each thread stages at most kMaxRegs vectors; vectors beyond that capacity
// are neither scanned nor rewritten.
// The body is compiled only when __CUDA_ARCH__ >= 700 or on the HIP platform.
__global__ void sr_quantize_kernel(__half* vals,
                                   int token_size,
                                   int token_num,
                                   int num_bits,
                                   std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
    // Capacity of the per-thread staging buffers below.
    constexpr int kMaxRegs = 128;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    // gid: index of this thread's 32-wide tile in the block; lane: position inside the tile.
    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;

    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // View the data as float2 so that one load/store moves four halves.
    float2* vals_cast = reinterpret_cast<float2*>(vals);

    __half2 data_low[kMaxRegs];
    __half2 data_high[kMaxRegs];

    int bid = blockIdx.x;

    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed.first, idx, seed.second, &state);
    unsigned int tid = threadIdx.x;
    int reg_count = 0;
    int offset = bid * token_size;
    int group_index = bid * token_size + tid;

    int total_count = token_size * token_num;
    // NOTE(review): threads failing this test skip the b.sync() inside the
    // branch; confirm that launches never leave part of a block outside it.
    if (group_index < total_count) {
        // float min = 10000.0;
        float max = -10000.0;

        // Pass 1: stage this thread's vectors and track the largest absolute
        // value. The loop also stops at the staging capacity so that an
        // oversized group cannot write past data_low / data_high.
        while (tid < token_size && reg_count < kMaxRegs) {
            float2 data = vals_cast[offset + tid];
            __half2* data_h = reinterpret_cast<__half2*>(&data);
            data_low[reg_count] = data_h[0];
            data_high[reg_count] = data_h[1];

            float2 data_f[2];
            data_f[0] = __half22float2(data_h[0]);
            data_f[1] = __half22float2(data_h[1]);

            if (abs((float)data_f[0].x) > max) max = abs((float)data_f[0].x);
            if (abs((float)data_f[0].y) > max) max = abs((float)data_f[0].y);
            if (abs((float)data_f[1].x) > max) max = abs((float)data_f[1].x);
            if (abs((float)data_f[1].y) > max) max = abs((float)data_f[1].y);

            tid += blockDim.x;
            reg_count++;
        }

        // Maximum across the 32 lanes of the tile.
#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

        // Lane 0 of every tile publishes its maximum; after the barrier each
        // tile reduces the published values to obtain the block-wide maximum.
        __shared__ float partialMax[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

        // Broadcast lane 0's result to the whole tile.
        max = g.shfl(max, 0);

        float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
        float high_q = (float)((1 << (num_bits - 1)) - 1);
        float low_q = (float)(-((1 << (num_bits - 1))));

        // Pass 2: quantize the staged values and write them back.
        for (int i = 0; i < reg_count; i++) {
            int token_index = i * blockDim.x + threadIdx.x;
            if (token_index < token_size) {
                float2 data_f[2];
                data_f[0] = __half22float2(data_low[i]);
                data_f[1] = __half22float2(data_high[i]);

                // Scale and truncate toward zero.
                float2 q_data_int[2];
                q_data_int[0].x = (float)((int)(data_f[0].x * q_scale_val));
                q_data_int[0].y = (float)((int)(data_f[0].y * q_scale_val));
                q_data_int[1].x = (float)((int)(data_f[1].x * q_scale_val));
                q_data_int[1].y = (float)((int)(data_f[1].y * q_scale_val));

                // Stochastic rounding
                float4 rand = hiprand_uniform4(&state);

                // Truncation error expressed in quantization steps.
                float q_error[4];
                q_error[0] = abs(data_f[0].x - (q_data_int[0].x / q_scale_val)) * q_scale_val;
                q_error[1] = abs(data_f[0].y - (q_data_int[0].y / q_scale_val)) * q_scale_val;
                q_error[2] = abs(data_f[1].x - (q_data_int[1].x / q_scale_val)) * q_scale_val;
                q_error[3] = abs(data_f[1].y - (q_data_int[1].y / q_scale_val)) * q_scale_val;

                q_data_int[0].x =
                    (rand.x < q_error[0] && q_data_int[0].x > low_q && q_data_int[0].x < high_q)
                        ? (q_data_int[0].x + (data_f[0].x > 0 ? 1 : -1))
                        : q_data_int[0].x;
                q_data_int[0].y =
                    (rand.y < q_error[1] && q_data_int[0].y > low_q && q_data_int[0].y < high_q)
                        ? (q_data_int[0].y + (data_f[0].y > 0 ? 1 : -1))
                        : q_data_int[0].y;
                q_data_int[1].x =
                    (rand.w < q_error[2] && q_data_int[1].x > low_q && q_data_int[1].x < high_q)
                        ? (q_data_int[1].x + (data_f[1].x > 0 ? 1 : -1))
                        : q_data_int[1].x;
                q_data_int[1].y =
                    (rand.z < q_error[3] && q_data_int[1].y > low_q && q_data_int[1].y < high_q)
                        ? (q_data_int[1].y + (data_f[1].y > 0 ? 1 : -1))
                        : q_data_int[1].y;

                // Restore the original scale.
                data_f[0].x = q_data_int[0].x / q_scale_val;
                data_f[0].y = q_data_int[0].y / q_scale_val;
                data_f[1].x = q_data_int[1].x / q_scale_val;
                data_f[1].y = q_data_int[1].y / q_scale_val;

                float2 result;
                __half2* result_h = reinterpret_cast<__half2*>(&result);
                result_h[0] = __float22half2_rn(data_f[0]);
                result_h[1] = __float22half2_rn(data_f[1]);

                vals_cast[offset + token_index] = result;
            }
        }
    }
#endif
}
// Symmetric quantize-and-restore of single-precision data with stochastic
// rounding, one group ("token") per block.
//
// vals        data, rewritten in place
// token_size  length of one group counted in float4 vectors (4 floats each)
// token_num   number of groups
// num_bits    quantization bit width
// seed        (seed, offset) pair used to initialise the per-thread Philox state
//
// Each block finds the largest absolute value of its group and derives the
// scale 2^num_bits / (2 * max + 1e-5). Every value is scaled and truncated
// toward zero; it is then moved one step away from zero when a uniform random
// number falls below the truncation error and the truncated value lies
// strictly between the lowest and highest representable level. The result is
// divided by the scale and written back.
//
// Each thread stages at most kMaxRegs vectors; vectors beyond that capacity
// are neither scanned nor rewritten.
__global__ void sr_quantize_kernel(float* vals,
                                   int token_size,
                                   int token_num,
                                   int num_bits,
                                   std::pair<uint64_t, uint64_t> seed)
{
    // Capacity of the per-thread staging buffer below.
    constexpr int kMaxRegs = 128;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    // gid: index of this thread's 32-wide tile in the block; lane: position inside the tile.
    int gid = threadIdx.x >> 5;
    int lane = threadIdx.x & 0x1f;
    int warp_num = blockDim.x >> 5;
    int id = threadIdx.x;

    int idx = blockIdx.x * blockDim.x + id;

    // View the data as float4 so that one load/store moves four floats.
    float4* vals_cast = reinterpret_cast<float4*>(vals);

    float4 data[kMaxRegs];

    int bid = blockIdx.x;
    int tid = threadIdx.x;
    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed.first, idx, seed.second, &state);

    int group_index = bid * token_size + threadIdx.x;
    int reg_count = 0;
    int total_count = token_size * token_num;
    // NOTE(review): threads failing this test skip the b.sync() inside the
    // branch; confirm that launches never leave part of a block outside it.
    if (group_index < total_count) {
        // float min = 10000.0;
        float max = -10000.0;

        // Pass 1: stage this thread's vectors and track the largest absolute
        // value. The loop also stops at the staging capacity so that an
        // oversized group cannot write past `data`.
        while (tid < token_size && reg_count < kMaxRegs) {
            data[reg_count] = vals_cast[group_index];

            if (abs(data[reg_count].x) > max) max = abs(data[reg_count].x);
            if (abs(data[reg_count].y) > max) max = abs(data[reg_count].y);
            if (abs(data[reg_count].z) > max) max = abs(data[reg_count].z);
            if (abs(data[reg_count].w) > max) max = abs(data[reg_count].w);

            group_index += blockDim.x;
            tid += blockDim.x;
            reg_count++;
        }

        // Maximum across the 32 lanes of the tile.
#pragma unroll
        for (int i = 1; i < WARP_SIZE; i <<= 1) {
            auto temp = g.shfl_xor(max, i);
            if (max < temp) max = temp;
        }

        // Lane 0 of every tile publishes its maximum; after the barrier each
        // tile reduces the published values to obtain the block-wide maximum.
        __shared__ float partialMax[WARP_SIZE];

        if (lane == 0) partialMax[gid] = max;

        b.sync();

        if (lane < warp_num) max = partialMax[lane];

#pragma unroll
        for (int i = 1; i < warp_num; i <<= 1) {
            auto temp = g.shfl_down(max, i);
            if (max < temp) max = temp;
        }

        // Broadcast lane 0's result to the whole tile.
        max = g.shfl(max, 0);

        float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
        float high_q = (float)((1 << (num_bits - 1)) - 1);
        float low_q = (float)(-((1 << (num_bits - 1))));

        int offset = (bid)*token_size;
        // Pass 2: quantize the staged values and write them back. Here
        // group_index is the position inside the group.
        for (int i = 0; i < reg_count; i++) {
            group_index = i * blockDim.x + threadIdx.x;
            if (group_index < token_size) {
                float4 q_data = data[i];

                // Scale and truncate toward zero.
                float4 q_data_int;
                q_data_int.x = (float)((int)(q_data.x * q_scale_val));
                q_data_int.y = (float)((int)(q_data.y * q_scale_val));
                q_data_int.w = (float)((int)(q_data.w * q_scale_val));
                q_data_int.z = (float)((int)(q_data.z * q_scale_val));

                // Stochastic rounding
                float4 rand = hiprand_uniform4(&state);

                // Truncation error expressed in quantization steps.
                float q_error[4];
                q_error[0] = abs(q_data.x - (q_data_int.x / q_scale_val)) * q_scale_val;
                q_error[1] = abs(q_data.y - (q_data_int.y / q_scale_val)) * q_scale_val;
                q_error[2] = abs(q_data.w - (q_data_int.w / q_scale_val)) * q_scale_val;
                q_error[3] = abs(q_data.z - (q_data_int.z / q_scale_val)) * q_scale_val;

                q_data_int.x =
                    (rand.x < q_error[0] && q_data_int.x > low_q && q_data_int.x < high_q)
                        ? (q_data_int.x + (q_data.x > 0 ? 1 : -1))
                        : q_data_int.x;
                q_data_int.y =
                    (rand.y < q_error[1] && q_data_int.y > low_q && q_data_int.y < high_q)
                        ? (q_data_int.y + (q_data.y > 0 ? 1 : -1))
                        : q_data_int.y;
                q_data_int.w =
                    (rand.w < q_error[2] && q_data_int.w > low_q && q_data_int.w < high_q)
                        ? (q_data_int.w + (q_data.w > 0 ? 1 : -1))
                        : q_data_int.w;
                q_data_int.z =
                    (rand.z < q_error[3] && q_data_int.z > low_q && q_data_int.z < high_q)
                        ? (q_data_int.z + (q_data.z > 0 ? 1 : -1))
                        : q_data_int.z;

                // Restore the original scale.
                q_data_int.x /= q_scale_val;
                q_data_int.y /= q_scale_val;
                q_data_int.w /= q_scale_val;
                q_data_int.z /= q_scale_val;

                vals_cast[group_index + offset] = q_data_int;
            }
        }
    }
}
template <typename T>
void launch_sr_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( sr_quantize_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_sr_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
// Asymmetric (min/max based) quantize-dequantize of __half data, in place.
//
// Each block processes the group selected by blockIdx.x. `group_size` counts
// float2-sized packets (four __half values each) per group. The block finds the
// group's min and max, builds a step of ((max - min) + 1e-5) / 2^num_bits,
// rounds every value to the nearest step above min and writes it back to vals.
// The body is only compiled for __CUDA_ARCH__ >= 700 or HIP; otherwise the
// kernel does nothing.
__global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
// gid: index of this thread's 32-lane tile in the block; lane: position inside that tile.
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
// Every float2 load/store moves four packed __half values.
float2* vals_cast = reinterpret_cast<float2*>(vals);
float2 data[MAX_REG];
int group_id = blockIdx.x;
{
int group_index = id;
int reg_count = 0;
int offset = group_id * group_size;
// NOTE(review): the running extrema start at -/+10000 rather than the type limits,
// so a group whose values all lie outside that range keeps the sentinel -- confirm intended.
float max = -10000.0;
float min = 10000.0;
// Stage this thread's packets in `data` while tracking the thread-local min/max.
// At most MAX_REG packets per thread are loaded; later packets are never visited.
while (group_index < group_size && reg_count < MAX_REG) {
data[reg_count] = vals_cast[offset + group_index];
__half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
if (((float)data_h[0]) > max) max = (float)data_h[0];
if (((float)data_h[1]) > max) max = (float)data_h[1];
if (((float)data_h[2]) > max) max = (float)data_h[2];
if (((float)data_h[3]) > max) max = (float)data_h[3];
if (((float)data_h[0]) < min) min = (float)data_h[0];
if (((float)data_h[1]) < min) min = (float)data_h[1];
if (((float)data_h[2]) < min) min = (float)data_h[2];
if (((float)data_h[3]) < min) min = (float)data_h[3];
group_index += blockDim.x;
reg_count++;
}
// XOR-shuffle exchange inside the 32-lane tile: each lane folds in its partner's max.
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
// Same exchange for the minimum.
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
// One slot per tile; lane 0 of each tile publishes that tile's result.
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
// Lanes below warp_num pick up one tile partial each; the shfl_down tree then
// folds them towards lane 0.
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
// Broadcast lane 0's value so every lane uses the same group-wide extrema.
max = g.shfl(max, 0);
min = g.shfl(min, 0);
// Step size; the 1e-5 keeps it non-zero when max == min.
float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_inv = 1 / q_scale;
// Second pass over the staged packets: quantize, dequantize and store.
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
__half2* data_h = reinterpret_cast<__half2*>(&data[i]);
float2 q_data[2];
q_data[0] = __half22float2(data_h[0]);
q_data[1] = __half22float2(data_h[1]);
float2 q_data_int[2];
// Offset from min expressed in steps, rounded to the nearest integer.
q_data_int[0].x = roundf((q_data[0].x - min) * q_scale_inv);
q_data_int[0].y = roundf((q_data[0].y - min) * q_scale_inv);
q_data_int[1].x = roundf((q_data[1].x - min) * q_scale_inv);
q_data_int[1].y = roundf((q_data[1].y - min) * q_scale_inv);
// Map the step index back to a real value.
q_data_int[0].x = q_data_int[0].x * q_scale + min;
q_data_int[0].y = q_data_int[0].y * q_scale + min;
q_data_int[1].x = q_data_int[1].x * q_scale + min;
q_data_int[1].y = q_data_int[1].y * q_scale + min;
// data_h aliases data[i], so these stores update the staged packet in place.
data_h[0] = __float22half2_rn(q_data_int[0]);
data_h[1] = __float22half2_rn(q_data_int[1]);
vals_cast[offset + group_index] = data[i];
}
}
}
#endif
}
// Asymmetric (min/max based) quantize-dequantize of float data, in place.
//
// Each block processes the group selected by blockIdx.x. `group_size` counts
// float4 packets per group. The block finds the group's min and max, builds a
// step of ((max - min) + 1e-5) / 2^num_bits, rounds every value to the nearest
// step above min and writes it back to vals.
__global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
// gid: index of this thread's 32-lane tile in the block; lane: position inside that tile.
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[MAX_REG];
int bid = blockIdx.x;
// During the load loop group_index is a global packet index; it is reused
// further down as a group-local index.
int group_index = bid * group_size + id;
int reg_count = 0;
// NOTE(review): the running extrema start at -/+10000 rather than the type limits,
// so a group whose values all lie outside that range keeps the sentinel -- confirm intended.
float max = -10000.0;
float min = 10000.0;
// Stage this thread's packets in `data` while tracking the thread-local min/max.
// At most MAX_REG packets per thread are loaded; later packets are never visited.
while (id < group_size && reg_count < MAX_REG) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (data_reg.x > max) max = data_reg.x;
if (data_reg.y > max) max = data_reg.y;
if (data_reg.w > max) max = data_reg.w;
if (data_reg.z > max) max = data_reg.z;
if (data_reg.x < min) min = data_reg.x;
if (data_reg.y < min) min = data_reg.y;
if (data_reg.w < min) min = data_reg.w;
if (data_reg.z < min) min = data_reg.z;
group_index += blockDim.x;
id += blockDim.x;
reg_count++;
}
// The loop advanced id; restore it to the thread index for the store pass.
id = threadIdx.x;
// XOR-shuffle exchange inside the 32-lane tile: each lane folds in its partner's max.
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
// Same exchange for the minimum.
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
// One slot per tile; lane 0 of each tile publishes that tile's result.
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
// Lanes below warp_num pick up one tile partial each; the shfl_down tree then
// folds them towards lane 0.
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
// Broadcast lane 0's value so every lane uses the same group-wide extrema.
max = g.shfl(max, 0);
min = g.shfl(min, 0);
// Step size; the 1e-5 keeps it non-zero when max == min.
float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_inv = 1 / q_scale;
// Second pass over the staged packets: quantize, dequantize and store.
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
float4 q_data;
q_data = data[i];
float4 q_data_int;
// Offset from min expressed in steps, rounded to the nearest integer.
q_data_int.x = roundf((q_data.x - min) * q_scale_inv);
q_data_int.y = roundf((q_data.y - min) * q_scale_inv);
q_data_int.w = roundf((q_data.w - min) * q_scale_inv);
q_data_int.z = roundf((q_data.z - min) * q_scale_inv);
// Map the step index back to a real value.
q_data.x = q_data_int.x * q_scale + min;
q_data.y = q_data_int.y * q_scale + min;
q_data.w = q_data_int.w * q_scale + min;
q_data.z = q_data_int.z * q_scale + min;
vals_cast[group_index + bid * group_size] = q_data;
}
}
}
template <typename T>
void launch_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
hipLaunchKernelGGL(( quantize_kernel_asym), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, num_bits);
}
template void launch_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
// Asymmetric quantize-dequantize of __half data with stochastic rounding, in place.
//
// Each block processes the token selected by blockIdx.x. `token_size` counts
// float2-sized packets (four __half values each) per token and `token_num` is
// the number of tokens. The block finds the token's min and max, builds a step
// of ((max - min) + 1e-5) / 2^num_bits, truncates every value to a step index
// and bumps that index by one with probability equal to the truncated fraction.
// `seed` supplies the first and third arguments of hiprand_init.
// The body is only compiled for __CUDA_ARCH__ >= 700 or HIP; otherwise the
// kernel does nothing.
__global__ void sr_quantize_kernel_asym(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
// gid: index of this thread's 32-lane tile in the block; lane: position inside that tile.
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
// Every float2 load/store moves four packed __half values.
float2* vals_cast = reinterpret_cast<float2*>(vals);
// Per-thread staging for the two __half2 halves of each packet.
__half2 data_low[128];
__half2 data_high[128];
int bid = blockIdx.x;
// The global thread index is passed as hiprand_init's second argument, so each
// thread initialises its generator differently.
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
unsigned int tid = threadIdx.x;
int reg_count = 0;
int offset = bid * token_size;
int group_index = bid * token_size + tid;
int total_count = token_size * token_num;
// NOTE(review): threads that fail this test skip the b.sync() below -- confirm the
// launch configuration never leaves a block only partially inside this branch.
if (group_index < total_count) {
// NOTE(review): the running extrema start at +/-10000 rather than the type limits.
float min = 10000.0;
float max = -10000.0;
// Stage this thread's packets while tracking the thread-local min/max.
// NOTE(review): reg_count is not bounded by the 128-entry staging arrays here.
while (tid < token_size) {
float2 data = vals_cast[offset + tid];
__half2* data_h = reinterpret_cast<__half2*>(&data);
data_low[reg_count] = data_h[0];
data_high[reg_count] = data_h[1];
float2 data_f[2];
data_f[0] = __half22float2(data_h[0]);
data_f[1] = __half22float2(data_h[1]);
if (((float)data_f[0].x) > max) max = (float)data_f[0].x;
if (((float)data_f[0].y) > max) max = (float)data_f[0].y;
if (((float)data_f[1].x) > max) max = (float)data_f[1].x;
if (((float)data_f[1].y) > max) max = (float)data_f[1].y;
if (((float)data_f[0].x) < min) min = (float)data_f[0].x;
if (((float)data_f[0].y) < min) min = (float)data_f[0].y;
if (((float)data_f[1].x) < min) min = (float)data_f[1].x;
if (((float)data_f[1].y) < min) min = (float)data_f[1].y;
tid += blockDim.x;
reg_count++;
}
// XOR-shuffle exchange inside the 32-lane tile: each lane folds in its partner's max.
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
// Same exchange for the minimum.
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
// One slot per tile; lane 0 of each tile publishes that tile's result.
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
// Lanes below warp_num pick up one tile partial each; the shfl_down tree then
// folds them towards lane 0.
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
// Broadcast lane 0's value so every lane uses the same token-wide extrema.
max = g.shfl(max, 0);
min = g.shfl(min, 0);
// Step size; the 1e-5 keeps it non-zero when max == min.
float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_val_inv = 1 / q_scale_val;
// Largest step index that may still be produced by rounding up.
float high_q = (float)((1 << num_bits) - 1);
// Second pass over the staged packets: quantize with stochastic rounding and store.
for (int i = 0; i < reg_count; i++) {
int token_index = i * blockDim.x + threadIdx.x;
if (token_index < token_size) {
float2 data_f[2];
data_f[0] = __half22float2(data_low[i]);
data_f[1] = __half22float2(data_high[i]);
float2 q_data_int[2];
// Truncate the offset from min (in steps) to an integer step index.
q_data_int[0].x = (float)((unsigned int)((data_f[0].x - min) * q_scale_val_inv));
q_data_int[0].y = (float)((unsigned int)((data_f[0].y - min) * q_scale_val_inv));
q_data_int[1].x = (float)((unsigned int)((data_f[1].x - min) * q_scale_val_inv));
q_data_int[1].y = (float)((unsigned int)((data_f[1].y - min) * q_scale_val_inv));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
// q_error: distance between the value and its truncated reconstruction, in steps.
float q_error[4];
q_error[0] =
abs(data_f[0].x - ((q_data_int[0].x * q_scale_val) + min)) * q_scale_val_inv;
q_error[1] =
abs(data_f[0].y - ((q_data_int[0].y * q_scale_val) + min)) * q_scale_val_inv;
q_error[2] =
abs(data_f[1].x - ((q_data_int[1].x * q_scale_val) + min)) * q_scale_val_inv;
q_error[3] =
abs(data_f[1].y - ((q_data_int[1].y * q_scale_val) + min)) * q_scale_val_inv;
// Round up when the random draw falls below the error and the index is below high_q.
// The third element consumes rand.w and the fourth rand.z.
q_data_int[0].x = (rand.x < q_error[0] && q_data_int[0].x < high_q)
? (q_data_int[0].x + 1)
: q_data_int[0].x;
q_data_int[0].y = (rand.y < q_error[1] && q_data_int[0].y < high_q)
? (q_data_int[0].y + 1)
: q_data_int[0].y;
q_data_int[1].x = (rand.w < q_error[2] && q_data_int[1].x < high_q)
? (q_data_int[1].x + 1)
: q_data_int[1].x;
q_data_int[1].y = (rand.z < q_error[3] && q_data_int[1].y < high_q)
? (q_data_int[1].y + 1)
: q_data_int[1].y;
// Map the step index back to a real value.
data_f[0].x = q_data_int[0].x * q_scale_val + min;
data_f[0].y = q_data_int[0].y * q_scale_val + min;
data_f[1].x = q_data_int[1].x * q_scale_val + min;
data_f[1].y = q_data_int[1].y * q_scale_val + min;
// Repack the four floats as two __half2 inside one float2-sized packet.
float2 result;
__half2* result_h = reinterpret_cast<__half2*>(&result);
result_h[0] = __float22half2_rn(data_f[0]);
result_h[1] = __float22half2_rn(data_f[1]);
vals_cast[offset + token_index] = result;
}
}
}
#endif
}
// Asymmetric quantize-dequantize of float data with stochastic rounding, in place.
//
// Each block processes the token selected by blockIdx.x. `token_size` counts
// float4 packets per token and `token_num` is the number of tokens. The block
// finds the token's min and max, builds a step of
// ((max - min) + 1e-5) / 2^num_bits, truncates every value to a step index and
// bumps that index by one with probability equal to the truncated fraction.
// `seed` supplies the first and third arguments of hiprand_init.
__global__ void sr_quantize_kernel_asym(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
// gid: index of this thread's 32-lane tile in the block; lane: position inside that tile.
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
int idx = blockIdx.x * blockDim.x + id;
float4* vals_cast = reinterpret_cast<float4*>(vals);
// Per-thread staging for the packets this thread owns.
float4 data[128];
int bid = blockIdx.x;
int tid = threadIdx.x;
// The global thread index is passed as hiprand_init's second argument, so each
// thread initialises its generator differently.
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
// During the load loop group_index is a global packet index; it is reused
// further down as a token-local index.
int group_index = bid * token_size + threadIdx.x;
int reg_count = 0;
int total_count = token_size * token_num;
// NOTE(review): threads that fail this test skip the b.sync() below -- confirm the
// launch configuration never leaves a block only partially inside this branch.
if (group_index < total_count) {
// NOTE(review): the running extrema start at +/-10000 rather than the type limits.
float min = 10000.0;
float max = -10000.0;
// Stage this thread's packets while tracking the thread-local min/max.
// NOTE(review): reg_count is not bounded by the 128-entry staging array here.
while (tid < token_size) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (data_reg.x > max) max = data_reg.x;
if (data_reg.y > max) max = data_reg.y;
if (data_reg.w > max) max = data_reg.w;
if (data_reg.z > max) max = data_reg.z;
if (data_reg.x < min) min = data_reg.x;
if (data_reg.y < min) min = data_reg.y;
if (data_reg.w < min) min = data_reg.w;
if (data_reg.z < min) min = data_reg.z;
group_index += blockDim.x;
tid += blockDim.x;
reg_count++;
}
// XOR-shuffle exchange inside the 32-lane tile: each lane folds in its partner's max.
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
// Same exchange for the minimum.
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
// One slot per tile; lane 0 of each tile publishes that tile's result.
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
// Lanes below warp_num pick up one tile partial each; the shfl_down tree then
// folds them towards lane 0.
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
// Broadcast lane 0's value so every lane uses the same token-wide extrema.
max = g.shfl(max, 0);
min = g.shfl(min, 0);
// Step size; the 1e-5 keeps it non-zero when max == min.
float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
// Largest step index that may still be produced by rounding up.
float high_q = (float)((1 << num_bits) - 1);
int offset = (bid)*token_size;
// Second pass over the staged packets: quantize with stochastic rounding and store.
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + threadIdx.x;
if (group_index < token_size) {
float4 q_data = data[i];
float4 q_data_int;
// Truncate the offset from min (in steps) to an integer step index.
q_data_int.x = (float)((int)((q_data.x - min) / q_scale_val));
q_data_int.y = (float)((int)((q_data.y - min) / q_scale_val));
q_data_int.w = (float)((int)((q_data.w - min) / q_scale_val));
q_data_int.z = (float)((int)((q_data.z - min) / q_scale_val));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
// q_error: distance between the value and its truncated reconstruction, in steps.
float q_error[4];
q_error[0] = abs(q_data.x - ((q_data_int.x * q_scale_val) + min)) / q_scale_val;
q_error[1] = abs(q_data.y - ((q_data_int.y * q_scale_val) + min)) / q_scale_val;
q_error[2] = abs(q_data.w - ((q_data_int.w * q_scale_val) + min)) / q_scale_val;
q_error[3] = abs(q_data.z - ((q_data_int.z * q_scale_val) + min)) / q_scale_val;
// Round up when the random draw falls below the error and the index is below high_q.
q_data_int.x = (rand.x < q_error[0] && q_data_int.x < high_q) ? (q_data_int.x + 1)
: q_data_int.x;
q_data_int.y = (rand.y < q_error[1] && q_data_int.y < high_q) ? (q_data_int.y + 1)
: q_data_int.y;
q_data_int.w = (rand.w < q_error[2] && q_data_int.w < high_q) ? (q_data_int.w + 1)
: q_data_int.w;
q_data_int.z = (rand.z < q_error[3] && q_data_int.z < high_q) ? (q_data_int.z + 1)
: q_data_int.z;
// Map the step index back to a real value.
q_data_int.x = q_data_int.x * q_scale_val + min;
q_data_int.y = q_data_int.y * q_scale_val + min;
q_data_int.w = q_data_int.w * q_scale_val + min;
q_data_int.z = q_data_int.z * q_scale_val + min;
vals_cast[group_index + offset] = q_data_int;
}
}
}
}
template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( sr_quantize_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_sr_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
deepspeed/ops/csrc/sparse_attention/utils.cpp
deleted
100644 → 0
View file @
1b2721ad
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
#include <torch/extension.h>
#include <string>
#include <tuple>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif
// Result list: one (block width, tensor of segment records) pair per width that produced output.
using ret_t = std::vector<std::tuple<int, torch::Tensor>>;
// Pulls every max_width x max_width tile of non-zero entries out of `layout`.
//
// For each head h the non-zero entries of layout[h] are scanned row by row.
// A running "square width" is kept in a scratch tensor; whenever it reaches
// max_width the entries of that tile are cleared in `layout` and one record
// {h, row, col, idx[h][row][col]} per entry is written to scratch[h].
// If any head produced records, (max_width, concatenated records) is appended
// to `ret`.
//   layout    : int tensor indexed [head][row][col]; modified in place
//   idx       : int tensor of the same shape holding the id of each entry
//   scratch   : int tensor indexed [head][record][4] used as output staging
//   max_width : tile edge length to look for
//   ret       : output list that receives the result for this width
void segment_blocks(torch::Tensor layout,
                    torch::Tensor idx,
                    torch::Tensor scratch,
                    int max_width,
                    ret_t& ret)
{
    const size_t num_heads = layout.size(0);
    const size_t num_rows = layout.size(1);
    const size_t num_cols = layout.size(2);

    // Per-entry running width, zero-initialised.
    torch::Tensor widths = torch::zeros_like(layout);
    auto width_acc = widths.accessor<int, 3>();
    auto layout_acc = layout.accessor<int, 3>();
    auto idx_acc = idx.accessor<int, 3>();
    auto scratch_acc = scratch.accessor<int, 3>();

    // Number of records written so far for each head.
    std::vector<int> emitted(num_heads, 0);
    // Slot of the most recent entry in the history buffers below.
    const int newest = max_width - 1;

#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (size_t h = 0; h < num_heads; h++) {
        // History of the last max_width visited columns, and per column the
        // last max_width visited rows; -1 marks "nothing seen yet".
        std::vector<int> recent_cols(max_width, -1);
        std::vector<std::vector<int>> recent_rows(max_width, std::vector<int>(num_cols, -1));

        for (size_t m = 0; m < num_rows; m++) {
            for (size_t n = 0; n < num_cols; n++) {
                if (layout_acc[h][m][n] == 0) { continue; }

                const int n_left = recent_cols[newest];
                const int m_top = recent_rows[newest][n];
                const bool has_top = (m_top >= 0);
                const bool has_left = (n_left >= 0);

                const int top = has_top ? width_acc[h][m_top][n] : 0;
                const int left = has_left ? width_acc[h][m][n_left] : 0;
                const int topleft = (has_top && has_left) ? width_acc[h][m_top][n_left] : 0;
                int width = std::min(std::min(top, topleft), left) + 1;

                // Restart at width 1 when a column between the left neighbour and
                // this one was visited more recently than this column's top neighbour
                // (the entries cannot be packed into one tile).
                for (int nn = n_left + 1; nn < n; nn++) {
                    if (recent_rows[newest][nn] > m_top) { width = 1; }
                }
                width_acc[h][m][n] = width;

                // Shift the column history and record this column.
                for (int k = 1; k < max_width; k++) { recent_cols[k - 1] = recent_cols[k]; }
                recent_cols[newest] = n;

                // Shift this column's row history and record this row.
                for (int k = 1; k < max_width; k++) { recent_rows[k - 1][n] = recent_rows[k][n]; }
                recent_rows[newest][n] = m;

                // Tile not complete yet.
                if (width != max_width) { continue; }

                // Emit the tile and clear its entries so they are not reused.
                for (size_t km = 0; km < max_width; km++) {
                    for (size_t kn = 0; kn < max_width; kn++) {
                        const int mm = recent_rows[km][n];
                        const int nn = recent_cols[kn];
                        if (mm < 0 || nn < 0) { continue; }

                        layout_acc[h][mm][nn] = 0;
                        width_acc[h][mm][nn] = 0;

                        const int slot = emitted[h];
                        scratch_acc[h][slot][0] = static_cast<int>(h);
                        scratch_acc[h][slot][1] = mm;
                        scratch_acc[h][slot][2] = nn;
                        scratch_acc[h][slot][3] = idx_acc[h][mm][nn];
                        emitted[h] = slot + 1;
                    }
                }
            }
        }
    }

    // Collect the filled prefix of each head's scratch area.
    std::vector<torch::Tensor> to_cat;
    for (size_t h = 0; h < num_heads; h++) {
        if (emitted[h] <= 0) { continue; }
        to_cat.push_back(scratch[h].slice(0, 0, emitted[h]));
    }

    if (to_cat.empty()) { return; }
    ret.push_back({max_width, torch::cat(to_cat)});
}
// Splits a block-sparse layout into square tiles of decreasing width.
//
// Every non-zero entry of `layout` first receives a running id (in
// head/row/column scan order). segment_blocks() is then called with
// start_width, start_width / 2, ... down to 1, each call appending its result
// to the returned list.
//   layout      : int tensor indexed [head][row][col]
//   start_width : first (largest) tile width to try
ret_t sdd_segment(torch::Tensor layout, int start_width)
{
    ret_t segments;

    // Id of each non-zero entry; zero elsewhere.
    torch::Tensor idx = torch::zeros_like(layout);

    const int64_t num_heads = layout.size(0);
    const int64_t num_rows = layout.size(1);
    const int64_t num_cols = layout.size(2);
    auto layout_acc = layout.accessor<int, 3>();
    auto idx_acc = idx.accessor<int, 3>();

    int next_id = 0;
    for (int64_t h = 0; h < num_heads; h++) {
        for (int64_t m = 0; m < num_rows; m++) {
            for (int64_t n = 0; n < num_cols; n++) {
                if (layout_acc[h][m][n] != 0) { idx_acc[h][m][n] = next_id++; }
            }
        }
    }

    // Staging area sized by the sum over the whole layout, four ints per record.
    const int record_capacity = layout.sum().item<int>();
    torch::Tensor scratch = torch::empty({num_heads, record_capacity, 4}, layout.dtype());

    int width = start_width;
    while (width > 0) {
        segment_blocks(layout, idx, scratch, width, segments);
        width /= 2;
    }
    return segments;
}
// Python bindings: exposes sdd_segment under the same name.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module)
{
    module.def("sdd_segment", &sdd_segment, "SDD segmentation handler");
}
deepspeed/ops/csrc/transformer/cublas_wrappers.cu
deleted
100644 → 0
View file @
1b2721ad
#include "cublas_wrappers.h"
// float GEMM wrapper around rocblas_gemm_ex (HIP) / cublasGemmEx (CUDA).
//
// A, B and C are float matrices and the compute type is 32-bit float. The
// leading dimension passed for A is m when transa is "no transpose" and k
// otherwise; for B it is k when transb is "no transpose" and n otherwise; for
// C it is always m. On HIP, C is passed as both the input and output matrix.
// Returns 0 on success; on failure prints a diagnostic to stderr and returns
// EXIT_FAILURE.
#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle,
                   rocblas_operation transa,
                   rocblas_operation transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const float* A,
                   const float* B,
                   float* C,
                   rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const float* A,
                   const float* B,
                   float* C,
                   cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    const int lda = (transa == rocblas_operation_none) ? m : k;
    const int ldb = (transb == rocblas_operation_none) ? k : n;
    const rocblas_status status = rocblas_gemm_ex(handle,
                                                  transa,
                                                  transb,
                                                  m,
                                                  n,
                                                  k,
                                                  static_cast<const void*>(alpha),
                                                  static_cast<const void*>(A),
                                                  rocblas_datatype_f32_r,
                                                  lda,
                                                  static_cast<const void*>(B),
                                                  rocblas_datatype_f32_r,
                                                  ldb,
                                                  static_cast<const void*>(beta),
                                                  C,
                                                  rocblas_datatype_f32_r,
                                                  m,
                                                  C,
                                                  rocblas_datatype_f32_r,
                                                  m,
                                                  rocblas_datatype_f32_r,
                                                  algo,
                                                  0,
                                                  0);
    const bool failed = (status != rocblas_status_success);
#else
    const int lda = (transa == CUBLAS_OP_N) ? m : k;
    const int ldb = (transb == CUBLAS_OP_N) ? k : n;
    const cublasStatus_t status = cublasGemmEx(handle,
                                               transa,
                                               transb,
                                               m,
                                               n,
                                               k,
                                               static_cast<const void*>(alpha),
                                               static_cast<const void*>(A),
                                               CUDA_R_32F,
                                               lda,
                                               static_cast<const void*>(B),
                                               CUDA_R_32F,
                                               ldb,
                                               static_cast<const void*>(beta),
                                               C,
                                               CUDA_R_32F,
                                               m,
                                               CUDA_R_32F,
                                               algo);
    const bool failed = (status != CUBLAS_STATUS_SUCCESS);
#endif

    if (failed) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                static_cast<int>(status));
        return EXIT_FAILURE;
    }
    return 0;
}
// __half GEMM wrapper around rocblas_gemm_ex (HIP) / cublasGemmEx (CUDA).
//
// A, B and C are __half matrices; alpha/beta are float and the compute type is
// 32-bit float. The leading dimension passed for A is m when transa is
// "no transpose" and k otherwise; for B it is k when transb is "no transpose"
// and n otherwise; for C it is always m. On HIP, C is passed as both the input
// and output matrix.
// Returns 0 on success; on failure prints a diagnostic to stderr and returns
// EXIT_FAILURE.
#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle,
                   rocblas_operation transa,
                   rocblas_operation transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const __half* A,
                   const __half* B,
                   __half* C,
                   rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const __half* A,
                   const __half* B,
                   __half* C,
                   cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    const int lda = (transa == rocblas_operation_none) ? m : k;
    const int ldb = (transb == rocblas_operation_none) ? k : n;
    const rocblas_status status = rocblas_gemm_ex(handle,
                                                  transa,
                                                  transb,
                                                  m,
                                                  n,
                                                  k,
                                                  static_cast<const void*>(alpha),
                                                  static_cast<const void*>(A),
                                                  rocblas_datatype_f16_r,
                                                  lda,
                                                  static_cast<const void*>(B),
                                                  rocblas_datatype_f16_r,
                                                  ldb,
                                                  static_cast<const void*>(beta),
                                                  static_cast<void*>(C),
                                                  rocblas_datatype_f16_r,
                                                  m,
                                                  static_cast<void*>(C),
                                                  rocblas_datatype_f16_r,
                                                  m,
                                                  rocblas_datatype_f32_r,
                                                  algo,
                                                  0,
                                                  0);
    const bool failed = (status != rocblas_status_success);
#else
    const int lda = (transa == CUBLAS_OP_N) ? m : k;
    const int ldb = (transb == CUBLAS_OP_N) ? k : n;
    const cublasStatus_t status = cublasGemmEx(handle,
                                               transa,
                                               transb,
                                               m,
                                               n,
                                               k,
                                               static_cast<const void*>(alpha),
                                               static_cast<const void*>(A),
                                               CUDA_R_16F,
                                               lda,
                                               static_cast<const void*>(B),
                                               CUDA_R_16F,
                                               ldb,
                                               static_cast<const void*>(beta),
                                               static_cast<void*>(C),
                                               CUDA_R_16F,
                                               m,
                                               CUDA_R_32F,
                                               algo);
    const bool failed = (status != CUBLAS_STATUS_SUCCESS);
#endif

    if (failed) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                static_cast<int>(status));
        return EXIT_FAILURE;
    }
    return 0;
}
// float strided-batched GEMM wrapper around rocblas_gemm_strided_batched_ex
// (HIP) / cublasGemmStridedBatchedEx (CUDA).
//
// `batch` problems are issued in one call; stride_A / stride_B / stride_C are
// forwarded as the per-matrix strides. The leading dimension passed for A is m
// when op_A is "no transpose" and k otherwise; for B it is k when op_B is
// "no transpose" and n otherwise; for C it is always m. On HIP, C is passed as
// both the input and output matrix.
// Returns 0 on success; on failure prints a diagnostic to stderr and returns
// EXIT_FAILURE.
#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const float* A,
                                const float* B,
                                float* C,
                                rocblas_operation op_A,
                                rocblas_operation op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const float* A,
                                const float* B,
                                float* C,
                                cublasOperation_t op_A,
                                cublasOperation_t op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    const int lda = (op_A == rocblas_operation_none) ? m : k;
    const int ldb = (op_B == rocblas_operation_none) ? k : n;
    const rocblas_status status = rocblas_gemm_strided_batched_ex(handle,
                                                                  op_A,
                                                                  op_B,
                                                                  m,
                                                                  n,
                                                                  k,
                                                                  alpha,
                                                                  A,
                                                                  rocblas_datatype_f32_r,
                                                                  lda,
                                                                  stride_A,
                                                                  B,
                                                                  rocblas_datatype_f32_r,
                                                                  ldb,
                                                                  stride_B,
                                                                  beta,
                                                                  C,
                                                                  rocblas_datatype_f32_r,
                                                                  m,
                                                                  stride_C,
                                                                  C,
                                                                  rocblas_datatype_f32_r,
                                                                  m,
                                                                  stride_C,
                                                                  batch,
                                                                  rocblas_datatype_f32_r,
                                                                  algo,
                                                                  0,
                                                                  0);
    const bool failed = (status != rocblas_status_success);
#else
    const int lda = (op_A == CUBLAS_OP_N) ? m : k;
    const int ldb = (op_B == CUBLAS_OP_N) ? k : n;
    const cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
                                                             op_A,
                                                             op_B,
                                                             m,
                                                             n,
                                                             k,
                                                             alpha,
                                                             A,
                                                             CUDA_R_32F,
                                                             lda,
                                                             stride_A,
                                                             B,
                                                             CUDA_R_32F,
                                                             ldb,
                                                             stride_B,
                                                             beta,
                                                             C,
                                                             CUDA_R_32F,
                                                             m,
                                                             stride_C,
                                                             batch,
                                                             CUDA_R_32F,
                                                             algo);
    const bool failed = (status != CUBLAS_STATUS_SUCCESS);
#endif

    if (failed) {
        fprintf(stderr,
                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
                batch,
                m,
                n,
                k,
                static_cast<int>(status));
        return EXIT_FAILURE;
    }
    return 0;
}
// __half strided-batched GEMM wrapper around rocblas_gemm_strided_batched_ex
// (HIP) / cublasGemmStridedBatchedEx (CUDA).
//
// A, B and C are __half; alpha/beta are float and the compute type is 32-bit
// float. `batch` problems are issued in one call; stride_A / stride_B /
// stride_C are forwarded as the per-matrix strides. The leading dimension
// passed for A is m when op_A is "no transpose" and k otherwise; for B it is k
// when op_B is "no transpose" and n otherwise; for C it is always m. On HIP, C
// is passed as both the input and output matrix.
// Returns 0 on success; on failure prints a diagnostic to stderr and returns
// EXIT_FAILURE.
#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const __half* A,
                                const __half* B,
                                __half* C,
                                rocblas_operation op_A,
                                rocblas_operation op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const __half* A,
                                const __half* B,
                                __half* C,
                                cublasOperation_t op_A,
                                cublasOperation_t op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    const int lda = (op_A == rocblas_operation_none) ? m : k;
    const int ldb = (op_B == rocblas_operation_none) ? k : n;
    const rocblas_status status = rocblas_gemm_strided_batched_ex(handle,
                                                                  op_A,
                                                                  op_B,
                                                                  m,
                                                                  n,
                                                                  k,
                                                                  alpha,
                                                                  A,
                                                                  rocblas_datatype_f16_r,
                                                                  lda,
                                                                  stride_A,
                                                                  B,
                                                                  rocblas_datatype_f16_r,
                                                                  ldb,
                                                                  stride_B,
                                                                  beta,
                                                                  C,
                                                                  rocblas_datatype_f16_r,
                                                                  m,
                                                                  stride_C,
                                                                  C,
                                                                  rocblas_datatype_f16_r,
                                                                  m,
                                                                  stride_C,
                                                                  batch,
                                                                  rocblas_datatype_f32_r,
                                                                  algo,
                                                                  0,
                                                                  0);
    const bool failed = (status != rocblas_status_success);
#else
    const int lda = (op_A == CUBLAS_OP_N) ? m : k;
    const int ldb = (op_B == CUBLAS_OP_N) ? k : n;
    const cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
                                                             op_A,
                                                             op_B,
                                                             m,
                                                             n,
                                                             k,
                                                             alpha,
                                                             A,
                                                             CUDA_R_16F,
                                                             lda,
                                                             stride_A,
                                                             B,
                                                             CUDA_R_16F,
                                                             ldb,
                                                             stride_B,
                                                             beta,
                                                             C,
                                                             CUDA_R_16F,
                                                             m,
                                                             stride_C,
                                                             batch,
                                                             CUDA_R_32F,
                                                             algo);
    const bool failed = (status != CUBLAS_STATUS_SUCCESS);
#endif

    if (failed) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                static_cast<int>(status));
        return EXIT_FAILURE;
    }
    return 0;
}
deepspeed/ops/csrc/transformer/cublas_wrappers.hip
deleted
100644 → 0
View file @
1b2721ad
// !!! This is a file automatically generated by hipify!!!
#include "cublas_wrappers_hip.h"
// float GEMM wrapper (hipified source).
//
// With __HIP_PLATFORM_HCC__ defined this forwards to rocblas_gemm_ex with
// float A/B/C and a 32-bit float compute type, passing C as both the input and
// output matrix. The leading dimension passed for A is m when transa is
// "no transpose" and k otherwise; for B it is k when transb is "no transpose"
// and n otherwise; for C it is always m.
// Returns 0 on success; on failure prints a diagnostic to stderr and returns
// EXIT_FAILURE.
#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle,
                   rocblas_operation transa,
                   rocblas_operation transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const float* A,
                   const float* B,
                   float* C,
                   rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(rocblas_handle handle,
                   rocblas_operation transa,
                   rocblas_operation transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const float* A,
                   const float* B,
                   float* C,
                   cublasGemmAlgo_t algo)
#endif
{
    // Both preprocessor branches use the same leading dimensions.
    const int lda = (transa == rocblas_operation_none) ? m : k;
    const int ldb = (transb == rocblas_operation_none) ? k : n;

#ifdef __HIP_PLATFORM_HCC__
    const rocblas_status status = rocblas_gemm_ex(handle,
                                                  transa,
                                                  transb,
                                                  m,
                                                  n,
                                                  k,
                                                  static_cast<const void*>(alpha),
                                                  static_cast<const void*>(A),
                                                  rocblas_datatype_f32_r,
                                                  lda,
                                                  static_cast<const void*>(B),
                                                  rocblas_datatype_f32_r,
                                                  ldb,
                                                  static_cast<const void*>(beta),
                                                  C,
                                                  rocblas_datatype_f32_r,
                                                  m,
                                                  C,
                                                  rocblas_datatype_f32_r,
                                                  m,
                                                  rocblas_datatype_f32_r,
                                                  algo,
                                                  0,
                                                  0);
#else
    const rocblas_status status = rocblas_gemmex(handle,
                                                 transa,
                                                 transb,
                                                 m,
                                                 n,
                                                 k,
                                                 static_cast<const void*>(alpha),
                                                 static_cast<const void*>(A),
                                                 hipR32F,
                                                 lda,
                                                 static_cast<const void*>(B),
                                                 hipR32F,
                                                 ldb,
                                                 static_cast<const void*>(beta),
                                                 C,
                                                 hipR32F,
                                                 m,
                                                 hipR32F,
                                                 algo);
#endif

    if (status != rocblas_status_success) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                static_cast<int>(status));
        return EXIT_FAILURE;
    }
    return 0;
}
// __half GEMM wrapper (hipified source).
//
// With __HIP_PLATFORM_HCC__ defined this forwards to rocblas_gemm_ex with
// __half A/B/C, float alpha/beta and a 32-bit float compute type, passing C as
// both the input and output matrix. The leading dimension passed for A is m
// when transa is "no transpose" and k otherwise; for B it is k when transb is
// "no transpose" and n otherwise; for C it is always m.
// Returns 0 on success; on failure prints a diagnostic to stderr and returns
// EXIT_FAILURE.
#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle,
                   rocblas_operation transa,
                   rocblas_operation transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const __half* A,
                   const __half* B,
                   __half* C,
                   rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(rocblas_handle handle,
                   rocblas_operation transa,
                   rocblas_operation transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const __half* A,
                   const __half* B,
                   __half* C,
                   cublasGemmAlgo_t algo)
#endif
{
    // Both preprocessor branches use the same leading dimensions.
    const int lda = (transa == rocblas_operation_none) ? m : k;
    const int ldb = (transb == rocblas_operation_none) ? k : n;

#ifdef __HIP_PLATFORM_HCC__
    const rocblas_status status = rocblas_gemm_ex(handle,
                                                  transa,
                                                  transb,
                                                  m,
                                                  n,
                                                  k,
                                                  static_cast<const void*>(alpha),
                                                  static_cast<const void*>(A),
                                                  rocblas_datatype_f16_r,
                                                  lda,
                                                  static_cast<const void*>(B),
                                                  rocblas_datatype_f16_r,
                                                  ldb,
                                                  static_cast<const void*>(beta),
                                                  static_cast<void*>(C),
                                                  rocblas_datatype_f16_r,
                                                  m,
                                                  static_cast<void*>(C),
                                                  rocblas_datatype_f16_r,
                                                  m,
                                                  rocblas_datatype_f32_r,
                                                  algo,
                                                  0,
                                                  0);
#else
    const rocblas_status status = rocblas_gemmex(handle,
                                                 transa,
                                                 transb,
                                                 m,
                                                 n,
                                                 k,
                                                 static_cast<const void*>(alpha),
                                                 static_cast<const void*>(A),
                                                 hipR16F,
                                                 lda,
                                                 static_cast<const void*>(B),
                                                 hipR16F,
                                                 ldb,
                                                 static_cast<const void*>(beta),
                                                 static_cast<void*>(C),
                                                 hipR16F,
                                                 m,
                                                 hipR32F,
                                                 algo);
#endif

    if (status != rocblas_status_success) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                static_cast<int>(status));
        return EXIT_FAILURE;
    }
    return 0;
}
// float strided-batched GEMM wrapper (hipified source).
//
// With __HIP_PLATFORM_HCC__ defined this forwards to
// rocblas_gemm_strided_batched_ex with float A/B/C and a 32-bit float compute
// type, passing C as both the input and output matrix. `batch` problems are
// issued in one call; stride_A / stride_B / stride_C are forwarded as the
// per-matrix strides. The leading dimension passed for A is m when op_A is
// "no transpose" and k otherwise; for B it is k when op_B is "no transpose"
// and n otherwise; for C it is always m.
// Returns 0 on success; on failure prints a diagnostic to stderr and returns
// EXIT_FAILURE.
#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const float* A,
                                const float* B,
                                float* C,
                                rocblas_operation op_A,
                                rocblas_operation op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(rocblas_handle handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const float* A,
                                const float* B,
                                float* C,
                                rocblas_operation op_A,
                                rocblas_operation op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
#endif
{
    // Both preprocessor branches use the same leading dimensions.
    const int lda = (op_A == rocblas_operation_none) ? m : k;
    const int ldb = (op_B == rocblas_operation_none) ? k : n;

#ifdef __HIP_PLATFORM_HCC__
    const rocblas_status status = rocblas_gemm_strided_batched_ex(handle,
                                                                  op_A,
                                                                  op_B,
                                                                  m,
                                                                  n,
                                                                  k,
                                                                  alpha,
                                                                  A,
                                                                  rocblas_datatype_f32_r,
                                                                  lda,
                                                                  stride_A,
                                                                  B,
                                                                  rocblas_datatype_f32_r,
                                                                  ldb,
                                                                  stride_B,
                                                                  beta,
                                                                  C,
                                                                  rocblas_datatype_f32_r,
                                                                  m,
                                                                  stride_C,
                                                                  C,
                                                                  rocblas_datatype_f32_r,
                                                                  m,
                                                                  stride_C,
                                                                  batch,
                                                                  rocblas_datatype_f32_r,
                                                                  algo,
                                                                  0,
                                                                  0);
#else
    const rocblas_status status = cublasGemmStridedBatchedEx(handle,
                                                             op_A,
                                                             op_B,
                                                             m,
                                                             n,
                                                             k,
                                                             alpha,
                                                             A,
                                                             hipR32F,
                                                             lda,
                                                             stride_A,
                                                             B,
                                                             hipR32F,
                                                             ldb,
                                                             stride_B,
                                                             beta,
                                                             C,
                                                             hipR32F,
                                                             m,
                                                             stride_C,
                                                             batch,
                                                             hipR32F,
                                                             algo);
#endif

    if (status != rocblas_status_success) {
        fprintf(stderr,
                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
                batch,
                m,
                n,
                k,
                static_cast<int>(status));
        return EXIT_FAILURE;
    }
    return 0;
}
// __half strided-batched GEMM wrapper (hipified source).
//
// With __HIP_PLATFORM_HCC__ defined this forwards to
// rocblas_gemm_strided_batched_ex with __half A/B/C, float alpha/beta and a
// 32-bit float compute type, passing C as both the input and output matrix.
// `batch` problems are issued in one call; stride_A / stride_B / stride_C are
// forwarded as the per-matrix strides. The leading dimension passed for A is m
// when op_A is "no transpose" and k otherwise; for B it is k when op_B is
// "no transpose" and n otherwise; for C it is always m.
// Returns 0 on success; on failure prints a diagnostic to stderr and returns
// EXIT_FAILURE.
#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const __half* A,
                                const __half* B,
                                __half* C,
                                rocblas_operation op_A,
                                rocblas_operation op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(rocblas_handle handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const __half* A,
                                const __half* B,
                                __half* C,
                                rocblas_operation op_A,
                                rocblas_operation op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
#endif
{
    // Both preprocessor branches use the same leading dimensions.
    const int lda = (op_A == rocblas_operation_none) ? m : k;
    const int ldb = (op_B == rocblas_operation_none) ? k : n;

#ifdef __HIP_PLATFORM_HCC__
    const rocblas_status status = rocblas_gemm_strided_batched_ex(handle,
                                                                  op_A,
                                                                  op_B,
                                                                  m,
                                                                  n,
                                                                  k,
                                                                  alpha,
                                                                  A,
                                                                  rocblas_datatype_f16_r,
                                                                  lda,
                                                                  stride_A,
                                                                  B,
                                                                  rocblas_datatype_f16_r,
                                                                  ldb,
                                                                  stride_B,
                                                                  beta,
                                                                  C,
                                                                  rocblas_datatype_f16_r,
                                                                  m,
                                                                  stride_C,
                                                                  C,
                                                                  rocblas_datatype_f16_r,
                                                                  m,
                                                                  stride_C,
                                                                  batch,
                                                                  rocblas_datatype_f32_r,
                                                                  algo,
                                                                  0,
                                                                  0);
#else
    const rocblas_status status = cublasGemmStridedBatchedEx(handle,
                                                             op_A,
                                                             op_B,
                                                             m,
                                                             n,
                                                             k,
                                                             alpha,
                                                             A,
                                                             hipR16F,
                                                             lda,
                                                             stride_A,
                                                             B,
                                                             hipR16F,
                                                             ldb,
                                                             stride_B,
                                                             beta,
                                                             C,
                                                             hipR16F,
                                                             m,
                                                             stride_C,
                                                             batch,
                                                             hipR32F,
                                                             algo);
#endif

    if (status != rocblas_status_success) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                static_cast<int>(status));
        return EXIT_FAILURE;
    }
    return 0;
}
deepspeed/ops/csrc/transformer/dropout_kernels.cu
deleted
100644 → 0
View file @
1b2721ad
#include "custom_cuda_layers.h"
// Number of elements handled per loop iteration by the kernels in this file;
// the launchers below assert that it equals 4.
const
int
unroll_factor
=
4
;
// Dropout forward pass for float data.
//
// For every element a uniform random number is drawn; the mask byte is 1 when
// the draw is greater than `ratio` and 0 otherwise, and the output is
// Xdata * (1 / (1 - ratio)) * mask.
//
//   N     - number of scalar elements indexed in Xdata / out / mask
//   ratio - threshold each uniform draw is compared against
//   out   - destination values
//   Xdata - source values
//   mask  - destination for the per-element 0/1 mask
//   seed  - pair forwarded to curand_init as (seed, offset); the flattened
//           thread index is passed as the subsequence
__global__ void dropout_kernel(const int N,
                               const float ratio,
                               float* out,
                               const float* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    const int global_tid = blockIdx.x * blockDim.x + threadIdx.x;

    curandStatePhilox4_32_10_t rng;
    curand_init(seed.first, global_tid, seed.second, &rng);

    // Main path: one curand_uniform4 call provides the four draws used for the
    // unroll_factor consecutive elements starting at j * unroll_factor.
    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        const float4 draw = curand_uniform4(&rng);
        const uint8_t keep[unroll_factor] = {(uint8_t)(draw.x > ratio),
                                             (uint8_t)(draw.y > ratio),
                                             (uint8_t)(draw.z > ratio),
                                             (uint8_t)(draw.w > ratio)};
        const int base = j * unroll_factor;

        // All mask bytes are stored before any output value is computed.
#pragma unroll
        for (int lane = 0; lane < unroll_factor; ++lane) {
            mask[base + lane] = (uint8_t)keep[lane];
        }
#pragma unroll
        for (int lane = 0; lane < unroll_factor; ++lane) {
            out[base + lane] = Xdata[base + lane] * scale * keep[lane];
        }
    }

    // Remainder path: a thread whose tail_start is below N walks the scalar
    // range [tail_start, N) using the components of one additional draw.
    const int tail_start =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) +
        threadIdx.x;
    if (N > tail_start) {
        float4 draw = curand_uniform4(&rng);
        const float* draw_lanes = &(draw.x);
        int next_draw = 0;
        for (int e = tail_start; e < N; e++) {
            const uint8_t keep = (uint8_t)(draw_lanes[next_draw++] > ratio);
            out[e] = Xdata[e] * scale * keep;
            mask[e] = keep;
        }
    }
}
// Dropout forward pass for __half data.
//
// Each element gets a mask byte of 1 when its uniform draw is greater than
// `ratio` (0 otherwise); the output is the input scaled by 1 / (1 - ratio)
// and multiplied by that mask.
//
// With __STOCHASTIC_MODE__ defined the multiply is carried out on __half2
// pairs and four halves / four mask bytes are moved as one float2 / uint32_t.
// Otherwise values are widened to float, scaled, and narrowed with
// __float2half one element at a time.
//
//   N     - number of scalar elements indexed in Xdata / out / mask
//   ratio - threshold each uniform draw is compared against
//   seed  - pair forwarded to curand_init as (seed, offset); the flattened
//           thread index is passed as the subsequence
__global__ void dropout_kernel(const int N,
                               const float ratio,
                               __half* out,
                               const __half* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    const int global_tid = blockIdx.x * blockDim.x + threadIdx.x;

    curandStatePhilox4_32_10_t rng;
    curand_init(seed.first, global_tid, seed.second, &rng);

#ifdef __STOCHASTIC_MODE__

    const __half2 h_scale = __float2half2_rn(scale);
    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        // Four halves are loaded as one float2 and viewed as two __half2 pairs.
        float2 packed_in = x_cast[j];
        __half2* in_h2 = reinterpret_cast<__half2*>(&packed_in);

        const float4 draw = curand_uniform4(&rng);

        // Four mask bytes are assembled inside one 32-bit word.
        uint32_t keep_word;
        uint8_t* keep = reinterpret_cast<uint8_t*>(&keep_word);
        keep[0] = (uint8_t)(draw.x > ratio);
        keep[1] = (uint8_t)(draw.y > ratio);
        keep[2] = (uint8_t)(draw.z > ratio);
        keep[3] = (uint8_t)(draw.w > ratio);

        float2 keep_f[2];
        keep_f[0].x = (float)(keep[0]);
        keep_f[0].y = (float)(keep[1]);
        keep_f[1].x = (float)(keep[2]);
        keep_f[1].y = (float)(keep[3]);

        __half2 keep_h[2];
        keep_h[0] = __float22half2_rn(keep_f[0]);
        keep_h[1] = __float22half2_rn(keep_f[1]);

        float2 packed_out;
        __half2* out_h2 = reinterpret_cast<__half2*>(&packed_out);
        out_h2[0] = in_h2[0] * h_scale * keep_h[0];
        out_h2[1] = in_h2[1] * h_scale * keep_h[1];

        out_cast[j] = packed_out;
        mask_cast[j] = keep_word;
    }

#else

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        const int base = j * unroll_factor;

        // Widen the four inputs to float before scaling.
        const __half2* in_h2 = reinterpret_cast<const __half2*>(Xdata + base);
        const float2 pair_lo = __half22float2(in_h2[0]);
        const float2 pair_hi = __half22float2(in_h2[1]);
        const float widened[unroll_factor] = {pair_lo.x, pair_lo.y, pair_hi.x, pair_hi.y};

        const float4 draw = curand_uniform4(&rng);
        const uint8_t keep[unroll_factor] = {(uint8_t)(draw.x > ratio),
                                             (uint8_t)(draw.y > ratio),
                                             (uint8_t)(draw.z > ratio),
                                             (uint8_t)(draw.w > ratio)};

        // Outputs are written first, then the mask bytes.
#pragma unroll
        for (int lane = 0; lane < unroll_factor; ++lane) {
            out[base + lane] = __float2half(widened[lane] * scale * keep[lane]);
        }
#pragma unroll
        for (int lane = 0; lane < unroll_factor; ++lane) { mask[base + lane] = keep[lane]; }
    }

#endif

    // Remainder path: a thread whose tail_start is below N walks the scalar
    // range [tail_start, N) using the components of one additional draw.
    const int tail_start =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) +
        threadIdx.x;
    if (N > tail_start) {
        float4 draw = curand_uniform4(&rng);
        const float* draw_lanes = &(draw.x);
        int next_draw = 0;
        for (int e = tail_start; e < N; e++) {
            const uint8_t keep = (uint8_t)(draw_lanes[next_draw++] > ratio);
            out[e] = __float2half((float)Xdata[e] * scale * keep);
            mask[e] = keep;
        }
    }
}
// Applies a previously stored dropout mask to float data.
//
// out[e] is Xdata[e] * (1 / (1 - ratio)) where mask[e] is non-zero and 0
// where it is zero. `seed` is accepted but not read by this kernel.
//
//   N     - number of scalar elements indexed in Xdata / out / mask
//   ratio - dropout ratio used to derive the scale factor
__global__ void dropout_kernel_bwd(const int N,
                                   const float ratio,
                                   const float* Xdata,
                                   float* out,
                                   uint8_t* mask,
                                   std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);

    // Main path: unroll_factor consecutive elements per iteration.
    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        const int base = j * unroll_factor;
#pragma unroll
        for (int lane = 0; lane < unroll_factor; ++lane) {
            const int e = base + lane;
            out[e] = mask[e] ? Xdata[e] * scale : 0.0;
        }
    }

    // Remainder path: a thread whose tail_start is below N walks [tail_start, N).
    const int tail_start =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) +
        threadIdx.x;
    if (N > tail_start) {
        for (int e = tail_start; e < N; e++) { out[e] = mask[e] ? Xdata[e] * scale : 0.0; }
    }
}
// Applies a previously stored dropout mask to __half data.
//
// Every output element is the input scaled by 1 / (1 - ratio) and multiplied
// by its mask byte. `seed` is accepted but not read by this kernel.
//
// With __STOCHASTIC_MODE__ defined the multiply is done on __half2 pairs, with
// four halves / four mask bytes moved as one float2 / uint32_t. Otherwise
// values are widened to float, scaled, and narrowed with __float2half.
//
//   N     - number of scalar elements indexed in Xdata / out / mask
//   ratio - dropout ratio used to derive the scale factor
__global__ void dropout_kernel_bwd(const int N,
                                   const float ratio,
                                   const __half* Xdata,
                                   __half* out,
                                   uint8_t* mask,
                                   std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);

#ifdef __STOCHASTIC_MODE__

    const __half2 h_scale = __float2half2_rn(scale);

    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        // Four halves are loaded as one float2 and viewed as two __half2 pairs.
        float2 x_f = x_cast[j];
        __half2* x_h = reinterpret_cast<__half2*>(&x_f);

        // Four mask bytes are loaded as one 32-bit word.
        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        __half2 mask_h[2];
        float2 mask_f[2];

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);

#pragma unroll
        for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = x_h[0] * h_scale * mask_h[0];
        result_h[1] = x_h[1] * h_scale * mask_h[1];

        out_cast[j] = result_f;
    }

#else

    // The unused __half copies of the scale factor and of zero that used to be
    // created here have been dropped; this branch only uses the float `scale`.
    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        int i = j * unroll_factor;

        const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
        uint8_t* m = mask + i;

        float2 vals_half_f[2];
        vals_half_f[0] = __half22float2(vals_half[0]);
        vals_half_f[1] = __half22float2(vals_half[1]);

        out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
        out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
        out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
        out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
    }

#endif

    // Remainder path: a thread whose high_index is below N walks [high_index, N).
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}
// Host launcher for the plain dropout kernels.
//
// Launches dropout_kernel_bwd when `bwd` is true and dropout_kernel otherwise,
// on `stream`, with `total_count` passed as the element count.
//
//   out, vals   - destination and source buffers
//   mask        - per-element mask buffer
//   total_count - element count; the grid is sized for total_count / unroll_factor
//   dim         - when greater than 512 the block size is halved and the grid
//                 size doubled
//   ratio       - dropout ratio forwarded to the kernel
template <typename T>
void launch_dropout(T* out,
                    const T* vals,
                    uint8_t* mask,
                    int total_count,
                    int dim,
                    float ratio,
                    cudaStream_t stream,
                    bool bwd)
{
    // unroll_factor is a compile-time constant, so the invariant is checked at
    // compile time (and is no longer compiled out by NDEBUG).
    static_assert(unroll_factor == 4, "dropout kernels require unroll_factor == 4");

    dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    if (dim > 512) {
        block_dim.x >>= 1;
        grid_dim.x <<= 1;
    }

    // NOTE(review): the offset is advanced for both directions, although
    // dropout_kernel_bwd does not read `seed`.
    uint64_t inc = total_count / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    if (bwd)
        dropout_kernel_bwd<<<grid_dim, block_dim, 0, stream>>>(
            total_count, ratio, vals, out, mask, seed);
    else
        dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
            total_count, ratio, out, vals, mask, seed);
}

// Explicit instantiations for float and __half.
template void launch_dropout(float* out,
                             const float* vals,
                             uint8_t* mask,
                             int total_count,
                             int dim,
                             float ratio,
                             cudaStream_t stream,
                             bool);
template void launch_dropout(__half* out,
                             const __half* vals,
                             uint8_t* mask,
                             int total_count,
                             int dim,
                             float ratio,
                             cudaStream_t stream,
                             bool);
// In-place dropout scaling for float data: every Xdata[i] visited by the loop
// is multiplied by (scale * mask[i]).
//
//   N     - loop bound handed to CUDA_1D_KERNEL_LOOP
//   scale - factor applied to elements whose mask byte is non-zero
__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
{
    CUDA_1D_KERNEL_LOOP(i, N)
    {
        // The mask byte is folded into the factor first, as in `x *= scale * m`.
        const float factor = scale * mask[i];
        Xdata[i] = Xdata[i] * factor;
    }
}
// In-place dropout scaling for __half data: every element of Xdata is replaced
// by its value times `scale` times its mask byte.
//
// Four halves and four mask bytes are loaded per iteration as one float2 and
// one uint32_t. With __STOCHASTIC_MODE__ defined the multiply is done on
// __half2 pairs; otherwise it is done in float and narrowed with
// __float22half2_rn.
//
//   N     - number of scalar elements indexed in Xdata / mask
//   scale - factor applied to elements whose mask byte is non-zero
__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
    const __half2 h_scale = __float2half2_rn(scale);
    float2* x_cast = reinterpret_cast<float2*>(Xdata);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 packed_in = x_cast[j];
        uint32_t keep_word = mask_cast[j];
        uint8_t* keep = (uint8_t*)&keep_word;

        float2 packed_out;
        __half2* out_h2 = reinterpret_cast<__half2*>(&packed_out);

#ifdef __STOCHASTIC_MODE__

        __half2* in_h2 = reinterpret_cast<__half2*>(&packed_in);

        // Convert the four mask bytes to two __half2 pairs.
        float2 keep_f[2];
        float* keep_f_lanes = &keep_f[0].x;
#pragma unroll
        for (int lane = 0; lane < unroll_factor; lane++) {
            keep_f_lanes[lane] = (float)(keep[lane]);
        }

        __half2 keep_h[2];
        keep_h[0] = __float22half2_rn(keep_f[0]);
        keep_h[1] = __float22half2_rn(keep_f[1]);

        out_h2[0] = in_h2[0] * h_scale * keep_h[0];
        out_h2[1] = in_h2[1] * h_scale * keep_h[1];

#else

        __half* in_h = reinterpret_cast<__half*>(&packed_in);

        // Scale each half in float, then narrow pairwise.
        float2 scaled[2];
        scaled[0].x = (float)in_h[0] * scale * keep[0];
        scaled[0].y = (float)in_h[1] * scale * keep[1];
        scaled[1].x = (float)in_h[2] * scale * keep[2];
        scaled[1].y = (float)in_h[3] * scale * keep[3];

        out_h2[0] = __float22half2_rn(scaled[0]);
        out_h2[1] = __float22half2_rn(scaled[1]);

#endif
        x_cast[j] = packed_out;
    }

    // Remainder path: a thread whose tail_start is below N walks [tail_start, N).
    const int tail_start =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) +
        threadIdx.x;
    if (N > tail_start) {
        for (int e = tail_start; e < N; e++) {
            Xdata[e] = __float2half((float)Xdata[e] * scale * mask[e]);
        }
    }
}
// Host launcher for the in-place dropout_grad_kernel.
//
// Computes scale = 1 / (1 - ratio) and launches the kernel on `stream` with a
// grid sized for total_count / unroll_factor and DS_CUDA_NUM_THREADS threads
// per block; `total_count` is passed as the kernel's element count.
//
//   vals        - buffer scaled in place
//   mask        - per-element mask buffer
//   total_count - element count
//   ratio       - dropout ratio used to derive the scale factor
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream)
{
    // unroll_factor is a compile-time constant, so the invariant is checked at
    // compile time (and is no longer compiled out by NDEBUG).
    static_assert(unroll_factor == 4, "dropout kernels require unroll_factor == 4");

    const float scale = 1. / (1. - ratio);
    dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor),
                          DS_CUDA_NUM_THREADS,
                          0,
                          stream>>>(total_count, scale, vals, mask);
}

// Explicit instantiations for float and __half.
template void launch_dropout_grad(float* vals,
                                  uint8_t* mask,
                                  int total_count,
                                  float ratio,
                                  cudaStream_t stream);
template void launch_dropout_grad(__half* vals,
                                  uint8_t* mask,
                                  int total_count,
                                  float ratio,
                                  cudaStream_t stream);
// Out-of-place dropout scaling for float data: out[i] = Xdata[i] * scale * mask[i]
// for every index visited by the loop.
//
//   N     - loop bound handed to CUDA_1D_KERNEL_LOOP
//   scale - factor applied to elements whose mask byte is non-zero
__global__ void dropout_grad_kernel(const int N,
                                    const float scale,
                                    const float* Xdata,
                                    float* out,
                                    uint8_t* mask)
{
    CUDA_1D_KERNEL_LOOP(i, N)
    {
        const float scaled = Xdata[i] * scale;
        out[i] = scaled * mask[i];
    }
}
// Out-of-place dropout scaling for __half data: every output element is the
// input times `scale` times its mask byte.
//
// Four halves and four mask bytes are loaded per iteration as one float2 and
// one uint32_t; the multiply is done in float and narrowed pairwise with
// __float22half2_rn.
//
//   N     - number of scalar elements indexed in Xdata / out / mask
//   scale - factor applied to elements whose mask byte is non-zero
__global__ void dropout_grad_kernel(const int N,
                                    const float scale,
                                    const __half* Xdata,
                                    __half* out,
                                    uint8_t* mask)
{
    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 packed_in = x_cast[j];
        uint32_t keep_word = mask_cast[j];
        uint8_t* keep = (uint8_t*)&keep_word;

        __half* in_h = reinterpret_cast<__half*>(&packed_in);

        // Scale each half in float.
        float2 scaled[2];
        scaled[0].x = (float)in_h[0] * scale * keep[0];
        scaled[0].y = (float)in_h[1] * scale * keep[1];
        scaled[1].x = (float)in_h[2] * scale * keep[2];
        scaled[1].y = (float)in_h[3] * scale * keep[3];

        // Narrow pairwise and store all four halves with one float2 write.
        float2 packed_out;
        __half2* out_h2 = reinterpret_cast<__half2*>(&packed_out);
        out_h2[0] = __float22half2_rn(scaled[0]);
        out_h2[1] = __float22half2_rn(scaled[1]);

        out_cast[j] = packed_out;
    }

    // Remainder path: a thread whose tail_start is below N walks [tail_start, N).
    const int tail_start =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) +
        threadIdx.x;
    if (N > tail_start) {
        for (int e = tail_start; e < N; e++) {
            out[e] = __float2half((float)Xdata[e] * scale * mask[e]);
        }
    }
}
// Host launcher for the out-of-place dropout_grad_kernel.
//
// Computes scale = 1 / (1 - ratio) and launches the kernel on `stream` with a
// grid sized for total_count / unroll_factor and DS_CUDA_NUM_THREADS threads
// per block; `total_count` is passed as the kernel's element count.
//
//   vals_out    - destination buffer
//   vals        - source buffer
//   mask        - per-element mask buffer
//   total_count - element count
//   ratio       - dropout ratio used to derive the scale factor
template <typename T>
void launch_dropout_grad(T* vals_out,
                         const T* vals,
                         uint8_t* mask,
                         int total_count,
                         float ratio,
                         cudaStream_t stream)
{
    // unroll_factor is a compile-time constant, so the invariant is checked at
    // compile time (and is no longer compiled out by NDEBUG).
    static_assert(unroll_factor == 4, "dropout kernels require unroll_factor == 4");

    const float scale = 1. / (1. - ratio);
    dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor),
                          DS_CUDA_NUM_THREADS,
                          0,
                          stream>>>(total_count, scale, vals, vals_out, mask);
}

// Explicit instantiations for float and __half.
template void launch_dropout_grad(float*,
                                  const float* vals,
                                  uint8_t* mask,
                                  int total_count,
                                  float ratio,
                                  cudaStream_t stream);
template void launch_dropout_grad(__half*,
                                  const __half* vals,
                                  uint8_t* mask,
                                  int total_count,
                                  float ratio,
                                  cudaStream_t stream);
// Fused bias-add + dropout for float data, applied in place on Xdata.
//
// Each loop iteration handles one float4 of Xdata: the matching float4 of
// `bias` (index j % (dim / unroll_factor)) is added, the sum is multiplied by
// 1 / (1 - ratio) and by a 0/1 mask byte that is 1 when the uniform draw is
// greater than `ratio`. The four mask bytes are stored as one uint32_t.
//
//   N     - loop bound of the main loop, which indexes Xdata as float4
//   dim   - bias length in scalar elements
//   ratio - threshold each uniform draw is compared against
//   seed  - pair forwarded to curand_init as (seed, offset); the flattened
//           thread index is passed as the subsequence
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const float* bias,
                               float* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // (The unused per-thread `tid` value computed here previously was removed.)

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        // Four mask bytes are assembled inside one 32-bit word.
        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        float4 x_data = Xdata_cast[j];
        float4 b_data = bias_cast[j % (dim / unroll_factor)];

        x_data.x += b_data.x;
        x_data.y += b_data.y;
        x_data.z += b_data.z;
        x_data.w += b_data.w;

        x_data.x = x_data.x * scale * m[0];
        x_data.y = x_data.y * scale * m[1];
        x_data.z = x_data.z * scale * m[2];
        x_data.w = x_data.w * scale * m[3];

        mask_32[j] = m_32;
        Xdata_cast[j] = x_data;
    }

    // Remainder path.
    // NOTE(review): the main loop uses N as a count of float4 vectors, while
    // this path uses N as a scalar index bound into Xdata / mask — confirm
    // this is intended.
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = Xdata[i] + bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            Xdata[i] = x_data * scale * m;
            mask[i] = m;
        }
    }
}
// Fused bias-add + dropout for __half data, applied in place on Xdata.
//
// Each loop iteration loads four halves of Xdata and four halves of `bias`
// (index j % (dim / unroll_factor)) as float2 values, widens them to float,
// adds them, multiplies by 1 / (1 - ratio) and by a 0/1 mask byte that is 1
// when the uniform draw is greater than `ratio`, and stores the result back
// as four halves. The four mask bytes are stored as one uint32_t.
//
//   N     - loop bound of the main loop, which indexes Xdata four halves at a time
//   dim   - bias length in scalar elements
//   ratio - threshold each uniform draw is compared against
//   seed  - pair forwarded to curand_init as (seed, offset); the flattened
//           thread index is passed as the subsequence
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const __half* bias,
                               __half* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // (The unused per-thread `tid` value computed here previously was removed.)

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        // float2 scratch values viewed as two __half2 pairs each.
        float2 data_f;
        __half2* data_h = reinterpret_cast<__half2*>(&data_f);

        float2 bias_f;
        __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);

        data_f = Xdata_cast[j];
        bias_f = bias_cast[j % (dim / unroll_factor)];

        float2 data_h_0 = __half22float2(data_h[0]);
        float2 data_h_1 = __half22float2(data_h[1]);

        float2 bias_h_0 = __half22float2(bias_h[0]);
        float2 bias_h_1 = __half22float2(bias_h[1]);

        data_h_0.x += bias_h_0.x;
        data_h_0.y += bias_h_0.y;
        data_h_1.x += bias_h_1.x;
        data_h_1.y += bias_h_1.y;

        // Four mask bytes are assembled inside one 32-bit word.
        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        // The product passes through __float2half before being stored back
        // into the float members.
        data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
        data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
        data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
        data_h_1.y = __float2half(data_h_1.y * scale * m[3]);

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = __float22half2_rn(data_h_0);
        result_h[1] = __float22half2_rn(data_h_1);

        Xdata_cast[j] = result_f;
        mask_32[j] = m_32;
    }

    // Remainder path.
    // NOTE(review): the main loop uses N as a count of four-half groups, while
    // this path uses N as a scalar index bound into Xdata / mask — confirm
    // this is intended.
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = (float)Xdata[i] + (float)bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            Xdata[i] = __float2half(x_data * scale * m);
            mask[i] = m;
        }
    }
}
// Host launcher for the fused bias-add + dropout kernels (in place on `out`).
//
// The kernel is launched on `stream` with batch * dim / unroll_factor as its
// element count and a grid sized for that same count.
//
//   out   - buffer updated in place
//   bias  - bias buffer
//   mask  - per-element mask buffer
//   batch - multiplied with dim to obtain the scalar element count
//   dim   - forwarded to the kernel for bias indexing
//   ratio - dropout ratio forwarded to the kernel
template <typename T>
void launch_dropout(T* out,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream)
{
    // unroll_factor is a compile-time constant, so the invariant is checked at
    // compile time (and is no longer compiled out by NDEBUG).
    static_assert(unroll_factor == 4, "dropout kernels require unroll_factor == 4");

    int total_count = batch * dim / unroll_factor;

    dim3 grid_dim = DS_GET_BLOCKS(total_count);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
        total_count, dim, ratio, bias, out, mask, seed);
}

// Explicit instantiations for float and __half.
template void launch_dropout(float*,
                             const float* bias,
                             uint8_t* mask,
                             int batch,
                             int dim,
                             float ratio,
                             cudaStream_t stream);
template void launch_dropout(__half*,
                             const __half* bias,
                             uint8_t* mask,
                             int batch,
                             int dim,
                             float ratio,
                             cudaStream_t stream);
// Fused bias-add + dropout + residual-add for float data.
//
// Each loop iteration handles one float4: out = (bias + input) * scale * mask
// + residual, where scale = 1 / (1 - ratio), the bias float4 is selected with
// j % (dim / unroll_factor), and each mask byte is 1 when its uniform draw is
// greater than `ratio`. The four mask bytes are stored as one uint32_t.
//
//   N     - loop bound of the main loop, which indexes the buffers as float4
//   dim   - bias length in scalar elements
//   ratio - threshold each uniform draw is compared against
//   seed  - pair forwarded to curand_init as (seed, offset); the flattened
//           thread index is passed as the subsequence
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const float* input,
                               const float* residual,
                               const float* bias,
                               float* out,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // (The unused per-thread `tid` value computed here previously was removed.)

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float4* out_cast = reinterpret_cast<float4*>(out);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);

    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
    const float4* residual_cast = reinterpret_cast<const float4*>(residual);
    const float4* input_cast = reinterpret_cast<const float4*>(input);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        // Four mask bytes are assembled inside one 32-bit word.
        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        float4 out_data;
        float4 b_data = bias_cast[j % (dim / unroll_factor)];
        float4 res_data = residual_cast[j];
        float4 inp_data = input_cast[j];

        out_data.x = (b_data.x + inp_data.x);
        out_data.y = (b_data.y + inp_data.y);
        out_data.z = (b_data.z + inp_data.z);
        out_data.w = (b_data.w + inp_data.w);

        out_data.x = out_data.x * scale * m[0];
        out_data.y = out_data.y * scale * m[1];
        out_data.z = out_data.z * scale * m[2];
        out_data.w = out_data.w * scale * m[3];

        // The residual is added after dropout, so it is never masked.
        out_data.x += res_data.x;
        out_data.y += res_data.y;
        out_data.z += res_data.z;
        out_data.w += res_data.w;

        mask_32[j] = m_32;
        out_cast[j] = out_data;
    }

    // Remainder path.
    // NOTE(review): the main loop uses N as a count of float4 vectors, while
    // this path uses N as a scalar index bound into the buffers — confirm
    // this is intended.
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = input[i] + bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            x_data = x_data * scale * m;
            x_data += residual[i];

            out[i] = x_data;
            mask[i] = m;
        }
    }
}
// Fused bias-add + dropout + residual-add for __half data.
//
// Each loop iteration loads four halves of input, residual and bias (bias
// index j % (dim / unroll_factor)) as float2 values, widens them to float and
// computes out = (bias + input) * scale * mask + residual, where
// scale = 1 / (1 - ratio) and each mask byte is 1 when its uniform draw is
// greater than `ratio`. Results are narrowed back to four halves; the four
// mask bytes are stored as one uint32_t.
//
//   N     - loop bound of the main loop, which indexes the buffers four halves at a time
//   dim   - bias length in scalar elements
//   ratio - threshold each uniform draw is compared against
//   seed  - pair forwarded to curand_init as (seed, offset); the flattened
//           thread index is passed as the subsequence
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const __half* input,
                               const __half* residual,
                               const __half* bias,
                               __half* out,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // (The unused per-thread `tid` value computed here previously was removed.)

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);

    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
    const float2* residual_cast = reinterpret_cast<const float2*>(residual);
    const float2* input_cast = reinterpret_cast<const float2*>(input);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        // float2 scratch values viewed as two __half2 pairs each.
        float2 bias_f;
        __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);

        float2 residual_f;
        __half2* residual_h = reinterpret_cast<__half2*>(&residual_f);

        float2 input_f;
        __half2* input_h = reinterpret_cast<__half2*>(&input_f);

        bias_f = bias_cast[j % (dim / unroll_factor)];
        residual_f = residual_cast[j];
        input_f = input_cast[j];

        float2 bias_h_0 = __half22float2(bias_h[0]);
        float2 bias_h_1 = __half22float2(bias_h[1]);

        float2 residual_h_0 = __half22float2(residual_h[0]);
        float2 residual_h_1 = __half22float2(residual_h[1]);

        float2 input_h_0 = __half22float2(input_h[0]);
        float2 input_h_1 = __half22float2(input_h[1]);

        // The accumulators are fully assigned here; they were previously
        // initialized from an uninitialized scratch float2 and then overwritten.
        float2 data_h_0;
        float2 data_h_1;

        data_h_0.x = (bias_h_0.x + input_h_0.x);
        data_h_0.y = (bias_h_0.y + input_h_0.y);
        data_h_1.x = (bias_h_1.x + input_h_1.x);
        data_h_1.y = (bias_h_1.y + input_h_1.y);

        // Four mask bytes are assembled inside one 32-bit word.
        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        // The product passes through __float2half before being stored back
        // into the float members.
        data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
        data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
        data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
        data_h_1.y = __float2half(data_h_1.y * scale * m[3]);

        // The residual is added after dropout, so it is never masked.
        data_h_0.x += residual_h_0.x;
        data_h_0.y += residual_h_0.y;
        data_h_1.x += residual_h_1.x;
        data_h_1.y += residual_h_1.y;

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = __float22half2_rn(data_h_0);
        result_h[1] = __float22half2_rn(data_h_1);

        out_cast[j] = result_f;
        mask_32[j] = m_32;
    }

    // Remainder path.
    // NOTE(review): the main loop uses N as a count of four-half groups, while
    // this path uses N as a scalar index bound into the buffers — confirm
    // this is intended.
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = (float)input[i] + (float)bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            x_data = x_data * scale * m;
            x_data += (float)residual[i];

            out[i] = __float2half(x_data);
            mask[i] = m;
        }
    }
}
// Host launcher for the fused bias-add + dropout + residual-add kernels.
//
// The kernel is launched on `stream` with batch * dim / unroll_factor as its
// element count and a grid sized for that same count.
//
//   out      - destination buffer
//   input    - source buffer
//   residual - buffer added after dropout
//   bias     - bias buffer
//   mask     - per-element mask buffer
//   batch    - multiplied with dim to obtain the scalar element count
//   dim      - forwarded to the kernel for bias indexing
//   ratio    - dropout ratio forwarded to the kernel
template <typename T>
void launch_dropout(T* out,
                    const T* input,
                    const T* residual,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream)
{
    // unroll_factor is a compile-time constant, so the invariant is checked at
    // compile time (and is no longer compiled out by NDEBUG).
    static_assert(unroll_factor == 4, "dropout kernels require unroll_factor == 4");

    int total_count = batch * dim / unroll_factor;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);

    dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
        total_count, dim, ratio, input, residual, bias, out, mask, seed);
}

// Explicit instantiations for float and __half.
template void launch_dropout(float*,
                             const float*,
                             const float* residual,
                             const float* bias,
                             uint8_t* mask,
                             int batch,
                             int dim,
                             float ratio,
                             cudaStream_t stream);
template void launch_dropout(__half*,
                             const __half*,
                             const __half* residual,
                             const __half* bias,
                             uint8_t* mask,
                             int batch,
                             int dim,
                             float ratio,
                             cudaStream_t stream);
deepspeed/ops/csrc/transformer/dropout_kernels.hip
deleted
100644 → 0
View file @
1b2721ad
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
// Number of elements handled per loop iteration by the kernels in this file;
// launch_dropout below asserts that it equals 4.
const int unroll_factor = 4;
// Dropout forward pass for float data (HIP build).
//
// For every element a uniform random number is drawn; the mask byte is 1 when
// the draw is greater than `ratio` and 0 otherwise, and the output is
// Xdata * (1 / (1 - ratio)) * mask.
//
//   N     - number of scalar elements indexed in Xdata / out / mask
//   ratio - threshold each uniform draw is compared against
//   seed  - pair forwarded to hiprand_init together with the flattened thread index
__global__ void dropout_kernel(const int N,
                               const float ratio,
                               float* out,
                               const float* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    const int global_tid = blockIdx.x * blockDim.x + threadIdx.x;

    hiprandStatePhilox4_32_10_t rng;
    hiprand_init(seed.first, global_tid, seed.second, &rng);

    // Main path: one hiprand_uniform4 call provides the four draws used for the
    // unroll_factor consecutive elements starting at j * unroll_factor.
    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        const float4 draw = hiprand_uniform4(&rng);
        const uint8_t keep[unroll_factor] = {(uint8_t)(draw.x > ratio),
                                             (uint8_t)(draw.y > ratio),
                                             (uint8_t)(draw.z > ratio),
                                             (uint8_t)(draw.w > ratio)};
        const int base = j * unroll_factor;

        // All mask bytes are stored before any output value is computed.
#pragma unroll
        for (int lane = 0; lane < unroll_factor; ++lane) {
            mask[base + lane] = (uint8_t)keep[lane];
        }
#pragma unroll
        for (int lane = 0; lane < unroll_factor; ++lane) {
            out[base + lane] = Xdata[base + lane] * scale * keep[lane];
        }
    }

    // Remainder path: a thread whose tail_start is below N walks the scalar
    // range [tail_start, N) using the components of one additional draw.
    const int tail_start =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) +
        threadIdx.x;
    if (N > tail_start) {
        float4 draw = hiprand_uniform4(&rng);
        const float* draw_lanes = &(draw.x);
        int next_draw = 0;
        for (int e = tail_start; e < N; e++) {
            const uint8_t keep = (uint8_t)(draw_lanes[next_draw++] > ratio);
            out[e] = Xdata[e] * scale * keep;
            mask[e] = keep;
        }
    }
}
// Dropout forward pass for __half data (HIP build).
//
// Each element gets a mask byte of 1 when its uniform draw is greater than
// `ratio` (0 otherwise); the output is the input scaled by 1 / (1 - ratio)
// and multiplied by that mask.
//
// With __STOCHASTIC_MODE__ defined the multiply is carried out on __half2
// pairs and four halves / four mask bytes are moved as one float2 / uint32_t.
// Otherwise values are widened to float, scaled, and narrowed with
// __float2half one element at a time.
//
//   N     - number of scalar elements indexed in Xdata / out / mask
//   ratio - threshold each uniform draw is compared against
//   seed  - pair forwarded to hiprand_init together with the flattened thread index
__global__ void dropout_kernel(const int N,
                               const float ratio,
                               __half* out,
                               const __half* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    const int global_tid = blockIdx.x * blockDim.x + threadIdx.x;

    hiprandStatePhilox4_32_10_t rng;
    hiprand_init(seed.first, global_tid, seed.second, &rng);

#ifdef __STOCHASTIC_MODE__

    const __half2 h_scale = __float2half2_rn(scale);
    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        // Four halves are loaded as one float2 and viewed as two __half2 pairs.
        float2 packed_in = x_cast[j];
        __half2* in_h2 = reinterpret_cast<__half2*>(&packed_in);

        const float4 draw = hiprand_uniform4(&rng);

        // Four mask bytes are assembled inside one 32-bit word.
        uint32_t keep_word;
        uint8_t* keep = reinterpret_cast<uint8_t*>(&keep_word);
        keep[0] = (uint8_t)(draw.x > ratio);
        keep[1] = (uint8_t)(draw.y > ratio);
        keep[2] = (uint8_t)(draw.z > ratio);
        keep[3] = (uint8_t)(draw.w > ratio);

        float2 keep_f[2];
        keep_f[0].x = (float)(keep[0]);
        keep_f[0].y = (float)(keep[1]);
        keep_f[1].x = (float)(keep[2]);
        keep_f[1].y = (float)(keep[3]);

        __half2 keep_h[2];
        keep_h[0] = __float22half2_rn(keep_f[0]);
        keep_h[1] = __float22half2_rn(keep_f[1]);

        float2 packed_out;
        __half2* out_h2 = reinterpret_cast<__half2*>(&packed_out);
        out_h2[0] = in_h2[0] * h_scale * keep_h[0];
        out_h2[1] = in_h2[1] * h_scale * keep_h[1];

        out_cast[j] = packed_out;
        mask_cast[j] = keep_word;
    }

#else

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        const int base = j * unroll_factor;

        // Widen the four inputs to float before scaling.
        const __half2* in_h2 = reinterpret_cast<const __half2*>(Xdata + base);
        const float2 pair_lo = __half22float2(in_h2[0]);
        const float2 pair_hi = __half22float2(in_h2[1]);
        const float widened[unroll_factor] = {pair_lo.x, pair_lo.y, pair_hi.x, pair_hi.y};

        const float4 draw = hiprand_uniform4(&rng);
        const uint8_t keep[unroll_factor] = {(uint8_t)(draw.x > ratio),
                                             (uint8_t)(draw.y > ratio),
                                             (uint8_t)(draw.z > ratio),
                                             (uint8_t)(draw.w > ratio)};

        // Outputs are written first, then the mask bytes.
#pragma unroll
        for (int lane = 0; lane < unroll_factor; ++lane) {
            out[base + lane] = __float2half(widened[lane] * scale * keep[lane]);
        }
#pragma unroll
        for (int lane = 0; lane < unroll_factor; ++lane) { mask[base + lane] = keep[lane]; }
    }

#endif

    // Remainder path: a thread whose tail_start is below N walks the scalar
    // range [tail_start, N) using the components of one additional draw.
    const int tail_start =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) +
        threadIdx.x;
    if (N > tail_start) {
        float4 draw = hiprand_uniform4(&rng);
        const float* draw_lanes = &(draw.x);
        int next_draw = 0;
        for (int e = tail_start; e < N; e++) {
            const uint8_t keep = (uint8_t)(draw_lanes[next_draw++] > ratio);
            out[e] = __float2half((float)Xdata[e] * scale * keep);
            mask[e] = keep;
        }
    }
}
// Applies a previously stored dropout mask to float data (HIP build).
//
// out[e] is Xdata[e] * (1 / (1 - ratio)) where mask[e] is non-zero and 0
// where it is zero. `seed` is accepted but not read by this kernel.
//
//   N     - number of scalar elements indexed in Xdata / out / mask
//   ratio - dropout ratio used to derive the scale factor
__global__ void dropout_kernel_bwd(const int N,
                                   const float ratio,
                                   const float* Xdata,
                                   float* out,
                                   uint8_t* mask,
                                   std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);

    // Main path: unroll_factor consecutive elements per iteration.
    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        const int base = j * unroll_factor;
#pragma unroll
        for (int lane = 0; lane < unroll_factor; ++lane) {
            const int e = base + lane;
            out[e] = mask[e] ? Xdata[e] * scale : 0.0;
        }
    }

    // Remainder path: a thread whose tail_start is below N walks [tail_start, N).
    const int tail_start =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) +
        threadIdx.x;
    if (N > tail_start) {
        for (int e = tail_start; e < N; e++) { out[e] = mask[e] ? Xdata[e] * scale : 0.0; }
    }
}
// Applies a previously stored dropout mask to __half data (HIP build).
//
// Every output element is the input scaled by 1 / (1 - ratio) and multiplied
// by its mask byte. `seed` is accepted but not read by this kernel.
//
// With __STOCHASTIC_MODE__ defined the multiply is done on __half2 pairs, with
// four halves / four mask bytes moved as one float2 / uint32_t. Otherwise
// values are widened to float, scaled, and narrowed with __float2half.
//
//   N     - number of scalar elements indexed in Xdata / out / mask
//   ratio - dropout ratio used to derive the scale factor
__global__ void dropout_kernel_bwd(const int N,
                                   const float ratio,
                                   const __half* Xdata,
                                   __half* out,
                                   uint8_t* mask,
                                   std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);

#ifdef __STOCHASTIC_MODE__

    const __half2 h_scale = __float2half2_rn(scale);

    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        // Four halves are loaded as one float2 and viewed as two __half2 pairs.
        float2 x_f = x_cast[j];
        __half2* x_h = reinterpret_cast<__half2*>(&x_f);

        // Four mask bytes are loaded as one 32-bit word.
        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        __half2 mask_h[2];
        float2 mask_f[2];

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);

#pragma unroll
        for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = x_h[0] * h_scale * mask_h[0];
        result_h[1] = x_h[1] * h_scale * mask_h[1];

        out_cast[j] = result_f;
    }

#else

    // The unused __half copies of the scale factor and of zero that used to be
    // created here have been dropped; this branch only uses the float `scale`.
    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        int i = j * unroll_factor;

        const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
        uint8_t* m = mask + i;

        float2 vals_half_f[2];
        vals_half_f[0] = __half22float2(vals_half[0]);
        vals_half_f[1] = __half22float2(vals_half[1]);

        out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
        out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
        out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
        out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
    }

#endif

    // Remainder path: a thread whose high_index is below N walks [high_index, N).
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}
template <typename T>
void launch_dropout(T* out,
const T* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool bwd)
{
assert(unroll_factor == 4);
dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
dim3 block_dim = DS_CUDA_NUM_THREADS;
if (dim > 512) {
block_dim.x >>= 1;
grid_dim.x <<= 1;
}
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
if (bwd)
hipLaunchKernelGGL(( dropout_kernel_bwd), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, vals, out, mask, seed);
else
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, out, vals, mask, seed);
}
// Explicit instantiations of the forward/backward dropout launcher
// (the overload taking a trailing `bool`) for float and __half.
template void launch_dropout(float* out,
const float* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
template void launch_dropout(__half* out,
const __half* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
// In-place FP32 dropout gradient: every element of Xdata is multiplied by
// `scale * mask[i]`, i.e. by `scale` where the mask byte is 1 and by 0 where it is 0.
__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
{
    CUDA_1D_KERNEL_LOOP(elem, N)
    {
        const float factor = scale * mask[elem];
        Xdata[elem] *= factor;
    }
}
// In-place FP16 dropout gradient.
//
// The main loop handles N / unroll_factor groups: each iteration loads four
// __half values as one float2 and the four matching mask bytes as one uint32_t,
// scales them by `scale * mask`, and stores the group back into Xdata.
// A scalar tail after the loop handles elements at indices >= high_index.
__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
const __half2 h_scale = __float2half2_rn(scale);
// Reinterpret the buffers so one index addresses four halves / four mask bytes.
float2* x_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
// Byte view of the packed mask word.
uint8_t* m = (uint8_t*)&m_32;
// result_f is the 8-byte store unit; result_h views it as two __half2.
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
#ifdef __STOCHASTIC_MODE__
// Stochastic mode: convert the mask to __half2 and multiply in half precision.
__half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_data_h[0] * h_scale * mask_h[0];
result_h[1] = x_data_h[1] * h_scale * mask_h[1];
#else
// Default mode: widen each half to float, scale, then convert back to half.
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
#endif
x_cast[j] = result_f;
}
// Scalar tail: first index derived from N, blockDim.x and threadIdx.x; only
// threads whose high_index is below N do any work here.
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
// Host-side launcher for the in-place dropout gradient kernels.
//
// vals        : gradient buffer of `total_count` elements, updated in place.
// mask        : one byte per element, as stored by the forward pass.
// ratio       : dropout ratio; the kernel receives scale = 1 / (1 - ratio).
// stream      : HIP stream the kernel is enqueued on.
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, hipStream_t stream)
{
    assert(unroll_factor == 4);

    const float scale = 1. / (1. - ratio);
    hipLaunchKernelGGL((dropout_grad_kernel),
                       dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
                       dim3(DS_CUDA_NUM_THREADS),
                       0,
                       stream,
                       total_count,
                       scale,
                       vals,
                       mask);
}
// Explicit instantiations of the in-place dropout-gradient launcher
// for float and __half.
template void launch_dropout_grad(float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
// Out-of-place FP32 dropout gradient: out[i] = Xdata[i] * scale * mask[i].
// Xdata is only read; the result is written to `out`.
__global__ void dropout_grad_kernel(const int N,
                                    const float scale,
                                    const float* Xdata,
                                    float* out,
                                    uint8_t* mask)
{
    CUDA_1D_KERNEL_LOOP(elem, N)
    {
        const float scaled = Xdata[elem] * scale;
        out[elem] = scaled * mask[elem];
    }
}
// Out-of-place FP16 dropout gradient: out = Xdata * scale * mask.
//
// The main loop handles N / unroll_factor groups of four __half values (loaded
// as one float2) together with four mask bytes (loaded as one uint32_t). The
// arithmetic is done in float and the result converted back to half before the
// 8-byte store. A scalar tail handles elements at indices >= high_index.
__global__ void dropout_grad_kernel(const int N,
const float scale,
const __half* Xdata,
__half* out,
uint8_t* mask)
{
// Reinterpret the buffers so one index addresses four halves / four mask bytes.
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);
// result_f is the 8-byte store unit; result_h views it as two __half2.
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
// Byte view of the packed mask word.
uint8_t* m = (uint8_t*)&m_32;
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
out_cast[j] = result_f;
}
// Scalar tail: first index derived from N, blockDim.x and threadIdx.x; only
// threads whose high_index is below N do any work here.
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
// Host-side launcher for the out-of-place dropout gradient kernels.
//
// vals_out    : destination buffer of `total_count` elements.
// vals        : source gradient buffer (read only).
// mask        : one byte per element, as stored by the forward pass.
// ratio       : dropout ratio; the kernel receives scale = 1 / (1 - ratio).
// stream      : HIP stream the kernel is enqueued on.
template <typename T>
void launch_dropout_grad(T* vals_out,
                         const T* vals,
                         uint8_t* mask,
                         int total_count,
                         float ratio,
                         hipStream_t stream)
{
    assert(unroll_factor == 4);

    const float scale = 1. / (1. - ratio);
    hipLaunchKernelGGL((dropout_grad_kernel),
                       dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
                       dim3(DS_CUDA_NUM_THREADS),
                       0,
                       stream,
                       total_count,
                       scale,
                       vals,
                       vals_out,
                       mask);
}
// Explicit instantiations of the out-of-place dropout-gradient launcher
// for float and __half.
template void launch_dropout_grad(float*,
const float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half*,
const __half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
// Fused bias-add + dropout (FP32), applied in place on Xdata.
//
// N     : number of float4 groups processed by the main loop.
// dim   : length of `bias`; the bias group for iteration j is j % (dim / unroll_factor).
// ratio : dropout ratio; an element is kept when its uniform sample is > ratio and
//         kept values are multiplied by scale = 1 / (1 - ratio).
// bias  : bias vector (read only).
// Xdata : input/output activations.
// mask  : receives one keep byte (0 or 1) per element.
// seed  : (seed, offset) pair used to initialise the per-thread Philox state.
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const float* bias,
                               float* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Each thread uses its global index as the Philox subsequence.
    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed.first, idx, seed.second, &state);

    // Vectorised views: one index addresses four floats / four mask bytes.
    float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = hiprand_uniform4(&state);

        // Build the four keep bytes inside one 32-bit word so they are stored at once.
        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;
        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        float4 x_data = Xdata_cast[j];
        float4 b_data = bias_cast[j % (dim / unroll_factor)];

        x_data.x += b_data.x;
        x_data.y += b_data.y;
        x_data.z += b_data.z;
        x_data.w += b_data.w;

        x_data.x = x_data.x * scale * m[0];
        x_data.y = x_data.y * scale * m[1];
        x_data.z = x_data.z * scale * m[2];
        x_data.w = x_data.w * scale * m[3];

        mask_32[j] = m_32;
        Xdata_cast[j] = x_data;
    }

    // Scalar tail.
    // NOTE(review): the launcher in this file passes N as a count of 4-element
    // groups, while this tail indexes Xdata/mask/bias per scalar element --
    // confirm the intended units before relying on this path.
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = hiprand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = Xdata[i] + bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            Xdata[i] = x_data * scale * m;
            mask[i] = m;
        }
    }
}
// Fused bias-add + dropout (FP16), applied in place on Xdata.
//
// N     : number of 4-half groups processed by the main loop.
// dim   : length of `bias`; the bias group for iteration j is j % (dim / unroll_factor).
// ratio : dropout ratio; an element is kept when its uniform sample is > ratio and
//         kept values are multiplied by scale = 1 / (1 - ratio).
// bias  : bias vector (read only).
// Xdata : input/output activations.
// mask  : receives one keep byte (0 or 1) per element.
// seed  : (seed, offset) pair used to initialise the per-thread Philox state.
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const __half* bias,
                               __half* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Each thread uses its global index as the Philox subsequence.
    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed.first, idx, seed.second, &state);

    // Vectorised views: one float2 carries four halves, one uint32_t four mask bytes.
    float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = hiprand_uniform4(&state);

        // 8-byte load units and their __half2 views.
        float2 data_f;
        __half2* data_h = reinterpret_cast<__half2*>(&data_f);
        float2 bias_f;
        __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);

        data_f = Xdata_cast[j];
        bias_f = bias_cast[j % (dim / unroll_factor)];

        // Widen to float for the arithmetic.
        float2 data_h_0 = __half22float2(data_h[0]);
        float2 data_h_1 = __half22float2(data_h[1]);
        float2 bias_h_0 = __half22float2(bias_h[0]);
        float2 bias_h_1 = __half22float2(bias_h[1]);

        data_h_0.x += bias_h_0.x;
        data_h_0.y += bias_h_0.y;
        data_h_1.x += bias_h_1.x;
        data_h_1.y += bias_h_1.y;

        // Build the four keep bytes inside one 32-bit word so they are stored at once.
        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;
        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        // The scaled value passes through __half before being stored back in the
        // float2 temporaries (kept exactly as in the original arithmetic).
        data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
        data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
        data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
        data_h_1.y = __float2half(data_h_1.y * scale * m[3]);

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);
        result_h[0] = __float22half2_rn(data_h_0);
        result_h[1] = __float22half2_rn(data_h_1);

        Xdata_cast[j] = result_f;
        mask_32[j] = m_32;
    }

    // Scalar tail.
    // NOTE(review): the launcher in this file passes N as a count of 4-element
    // groups, while this tail indexes Xdata/mask/bias per scalar element --
    // confirm the intended units before relying on this path.
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = hiprand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = (float)Xdata[i] + (float)bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            Xdata[i] = __float2half(x_data * scale * m);
            mask[i] = m;
        }
    }
}
template <typename T>
void launch_dropout(T* out,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, bias, out, mask, seed);
}
// Explicit instantiations of the fused bias-add + dropout launcher
// for float and __half.
template void launch_dropout(float*,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
// Fused bias-add + dropout + residual-add (FP32):
//     out = (input + bias) * scale * mask + residual
//
// N        : number of float4 groups processed by the main loop.
// dim      : length of `bias`; the bias group for iteration j is j % (dim / unroll_factor).
// ratio    : dropout ratio; an element is kept when its uniform sample is > ratio and
//            kept values are multiplied by scale = 1 / (1 - ratio).
// input    : values to which bias and dropout are applied (read only).
// residual : added after dropout (read only).
// bias     : bias vector (read only).
// out      : destination buffer.
// mask     : receives one keep byte (0 or 1) per element.
// seed     : (seed, offset) pair used to initialise the per-thread Philox state.
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const float* input,
                               const float* residual,
                               const float* bias,
                               float* out,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Each thread uses its global index as the Philox subsequence.
    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed.first, idx, seed.second, &state);

    // Vectorised views: one index addresses four floats / four mask bytes.
    float4* out_cast = reinterpret_cast<float4*>(out);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
    const float4* residual_cast = reinterpret_cast<const float4*>(residual);
    const float4* input_cast = reinterpret_cast<const float4*>(input);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = hiprand_uniform4(&state);

        // Build the four keep bytes inside one 32-bit word so they are stored at once.
        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;
        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        float4 out_data;
        float4 b_data = bias_cast[j % (dim / unroll_factor)];
        float4 res_data = residual_cast[j];
        float4 inp_data = input_cast[j];

        out_data.x = (b_data.x + inp_data.x);
        out_data.y = (b_data.y + inp_data.y);
        out_data.z = (b_data.z + inp_data.z);
        out_data.w = (b_data.w + inp_data.w);

        out_data.x = out_data.x * scale * m[0];
        out_data.y = out_data.y * scale * m[1];
        out_data.z = out_data.z * scale * m[2];
        out_data.w = out_data.w * scale * m[3];

        // The residual is added after dropout, so it is never masked.
        out_data.x += res_data.x;
        out_data.y += res_data.y;
        out_data.z += res_data.z;
        out_data.w += res_data.w;

        mask_32[j] = m_32;
        out_cast[j] = out_data;
    }

    // Scalar tail.
    // NOTE(review): the launcher in this file passes N as a count of 4-element
    // groups, while this tail indexes the buffers per scalar element --
    // confirm the intended units before relying on this path.
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = hiprand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = input[i] + bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            x_data = x_data * scale * m;
            x_data += residual[i];
            out[i] = x_data;
            mask[i] = m;
        }
    }
}
// Fused bias-add + dropout + residual-add (FP16):
//     out = (input + bias) * scale * mask + residual
//
// N        : number of 4-half groups processed by the main loop.
// dim      : length of `bias`; the bias group for iteration j is j % (dim / unroll_factor).
// ratio    : dropout ratio; an element is kept when its uniform sample is > ratio and
//            kept values are multiplied by scale = 1 / (1 - ratio).
// input    : values to which bias and dropout are applied (read only).
// residual : added after dropout (read only).
// bias     : bias vector (read only).
// out      : destination buffer.
// mask     : receives one keep byte (0 or 1) per element.
// seed     : (seed, offset) pair used to initialise the per-thread Philox state.
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const __half* input,
                               const __half* residual,
                               const __half* bias,
                               __half* out,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Each thread uses its global index as the Philox subsequence.
    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed.first, idx, seed.second, &state);

    // Vectorised views: one float2 carries four halves, one uint32_t four mask bytes.
    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
    const float2* residual_cast = reinterpret_cast<const float2*>(residual);
    const float2* input_cast = reinterpret_cast<const float2*>(input);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = hiprand_uniform4(&state);

        // 8-byte load units and their __half2 views.
        float2 bias_f;
        __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
        float2 residual_f;
        __half2* residual_h = reinterpret_cast<__half2*>(&residual_f);
        float2 input_f;
        __half2* input_h = reinterpret_cast<__half2*>(&input_f);

        bias_f = bias_cast[j % (dim / unroll_factor)];
        residual_f = residual_cast[j];
        input_f = input_cast[j];

        // Widen to float for the arithmetic.
        float2 bias_h_0 = __half22float2(bias_h[0]);
        float2 bias_h_1 = __half22float2(bias_h[1]);
        float2 residual_h_0 = __half22float2(residual_h[0]);
        float2 residual_h_1 = __half22float2(residual_h[1]);
        float2 input_h_0 = __half22float2(input_h[0]);
        float2 input_h_1 = __half22float2(input_h[1]);

        // Accumulators are written before they are read (no load from an
        // uninitialised temporary).
        float2 data_h_0;
        float2 data_h_1;
        data_h_0.x = (bias_h_0.x + input_h_0.x);
        data_h_0.y = (bias_h_0.y + input_h_0.y);
        data_h_1.x = (bias_h_1.x + input_h_1.x);
        data_h_1.y = (bias_h_1.y + input_h_1.y);

        // Build the four keep bytes inside one 32-bit word so they are stored at once.
        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;
        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        // The scaled value passes through __half before being stored back in the
        // float2 temporaries (kept exactly as in the original arithmetic).
        data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
        data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
        data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
        data_h_1.y = __float2half(data_h_1.y * scale * m[3]);

        // The residual is added after dropout, so it is never masked.
        data_h_0.x += residual_h_0.x;
        data_h_0.y += residual_h_0.y;
        data_h_1.x += residual_h_1.x;
        data_h_1.y += residual_h_1.y;

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);
        result_h[0] = __float22half2_rn(data_h_0);
        result_h[1] = __float22half2_rn(data_h_1);

        out_cast[j] = result_f;
        mask_32[j] = m_32;
    }

    // Scalar tail.
    // NOTE(review): the launcher in this file passes N as a count of 4-element
    // groups, while this tail indexes the buffers per scalar element --
    // confirm the intended units before relying on this path.
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = hiprand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = (float)input[i] + (float)bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            x_data = x_data * scale * m;
            x_data += (float)residual[i];
            out[i] = __float2half(x_data);
            mask[i] = m;
        }
    }
}
template <typename T>
void launch_dropout(T* out,
const T* input,
const T* residual,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, input, residual, bias, out, mask, seed);
}
// Explicit instantiations of the fused bias-add + dropout + residual-add
// launcher for float and __half.
template void launch_dropout(float*,
const float*,
const float* residual,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half*,
const __half* residual,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
deepspeed/ops/csrc/transformer/ds_transformer_cuda.cpp
deleted
100644 → 0
View file @
1b2721ad
#include <torch/extension.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer.h"
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
#include "ds_transformer_cuda.h"
// Registry of created layers keyed by layer id. Layers are stored type-erased
// (shared_ptr<void>) so float and __half instantiations can share one map.
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

// Sequence length every layer is constructed with.
const int init_seq_length = 128;
// C++ interface

// Computes the workspace size, in elements of T (not bytes), for the given
// layer dimensions.
//
// Always counted: 4 * batch * seq * hidden.
// When `training`:
//   + 2 * batch * seq * hidden
//   + max(batch * seq * intermediate, 2 * batch * heads * seq * seq)
//   + 2 * batch * seq * intermediate   (only if `gelu_checkpoint`)
//
// The terms are evaluated in size_t and the sum is narrowed to `unsigned` on
// return, so very large configurations wrap modulo 2^32.
template <typename T>
unsigned get_workspace_size(unsigned maxBatchSize,
                            unsigned seq_len,
                            unsigned hidden_size,
                            unsigned intermediate_size,
                            unsigned heads,
                            bool training,
                            bool gelu_checkpoint)
{
    const size_t tokens = size_t(maxBatchSize) * seq_len;
    const size_t hidden_elems = tokens * hidden_size;
    const size_t inter_elems = tokens * intermediate_size;
    const size_t score_elems = tokens * heads * seq_len;

    size_t total = 4 * hidden_elems;
    if (training) {
        total += 2 * hidden_elems;
        total += (std::max)(inter_elems, 2 * score_elems);
        if (gelu_checkpoint) { total += 2 * inter_elems; }
    }

    // Element count only; the caller multiplies by sizeof(T) if it needs bytes.
    return static_cast<unsigned>(total);
}
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.

// Tensor argument checks. Each macro stringifies its argument into the message.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
// Runs both checks. Wrapped in do/while(0) so the macro expands to a single
// statement and stays correct under an unbraced `if`/`else`.
#define CHECK_INPUT(x)       \
    do {                     \
        CHECK_CUDA(x);       \
        CHECK_CONTIGUOUS(x); \
    } while (0)
// Builds one transformer layer: stores the configuration flags and dimensions,
// fetches the stream and cuBLAS handle from Context, and configures every
// sub-operator (linear layers, layer norms, softmax, GELU, dropouts and the two
// strided-batch GEMMs used for attention). `gemm_algos` entries 0..4 select the
// GEMM algorithms for: qkv / attention-output linears (0), ff1 (1), ff2 (2),
// attention scores (3) and attention context (4).
template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
                                              unsigned batch_size,
                                              unsigned hidden_size,
                                              unsigned num_heads,
                                              unsigned intermediate_size,
                                              unsigned seq_length,
                                              float attn_prob_dropout_ratio,
                                              float hidden_output_dropout_ratio,
                                              float layer_norm_eps,
                                              bool pre_or_postLayerNorm,
                                              const std::vector<std::array<int, 3>>& gemm_algos,
                                              bool attn_dropout_checkpoint,
                                              bool normalize_invertible,
                                              bool gelu_checkpoint,
                                              bool stochastic_mode)
    : _layer_id(layer_id),
      _batch_size(batch_size),
      _hidden_size(hidden_size),
      _heads(num_heads),
      _intermediate_size(intermediate_size),
      _seq_length(seq_length),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _attn_dropout_checkpoint(attn_dropout_checkpoint),
      _normalize_invertible(normalize_invertible),
      _gelu_checkpoint(gelu_checkpoint),
      _stochastic_mode(stochastic_mode),
      _stream(Context::Instance().GetCurrentStream()),
      _cublasHandle(Context::Instance().GetCublasHandle()),
      _qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                  3 * hidden_size,
                                                  hidden_size,
                                                  gemm_algos[0])),
      _attn_out_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                       hidden_size,
                                                       hidden_size,
                                                       gemm_algos[0])),
      _attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                           seq_length,
                                                           hidden_size,
                                                           layer_norm_eps,
                                                           true,
                                                           !normalize_invertible)),
      _layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                      seq_length,
                                                      hidden_size,
                                                      layer_norm_eps,
                                                      true,
                                                      !normalize_invertible)),
      _ff1(typename FeedForward<T>::Config(batch_size * seq_length,
                                           _intermediate_size,
                                           hidden_size,
                                           gemm_algos[1])),
      _ff2(typename FeedForward<T>::Config(batch_size * seq_length,
                                           hidden_size,
                                           _intermediate_size,
                                           gemm_algos[2])),
      _softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
      _gelu(typename Gelu<T>::Config(_intermediate_size)),
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
      _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _attn_scores(typename StridedBatchGemm<T>::Config(
          _batch_size * _heads,
          _seq_length,
          _seq_length,
          _hidden_size / _heads,
          // Previous form: (T(1.0) / T(sqrt(_hidden_size / _heads))).
          // The reciprocal is now computed first and converted to T once.
          (T(1.0 / (sqrt(_hidden_size / _heads)))),
          T(0.0),
          CUBLAS_OP_T,
          CUBLAS_OP_N,
          gemm_algos[3])),
      _attn_context(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                         _hidden_size / _heads,
                                                         _seq_length,
                                                         _seq_length,
                                                         T(1.0),
                                                         T(0.0),
                                                         CUBLAS_OP_N,
                                                         CUBLAS_OP_N,
                                                         gemm_algos[4]))
{
    // The hidden dimension must split evenly across the attention heads.
    assert(_hidden_size % _heads == 0);

    Initialize();
}
// Nothing to release explicitly here; the body is intentionally empty.
template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer()
{
}
// One-time setup run from the constructor. On non-HIP builds, the __half
// instantiation switches the cuBLAS handle to CUBLAS_TENSOR_OP_MATH; on HIP
// builds (and for other T) this is a no-op.
template <typename T>
void BertTransformerLayer<T>::Initialize()
{
#ifndef __HIP_PLATFORM_HCC__
    if (std::is_same<T, __half>::value) {
        cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
    }
#endif
}
// Forward pass of one transformer layer for `bsz` sequences.
//
// The const pointers are the layer input, attention mask and parameters. The
// non-const pointers receive the output (out_ptr) and the intermediate
// activations. Scratch buffers are carved out of the shared workspace returned
// by Context::GetWorkSpace() in units of bsz * seq_length * hidden_size
// elements. Depending on the checkpoint flags, add_res_ptr and ctx_bufB_ptr are
// redirected into that workspace instead of using the caller's buffers.
//
// _pre_or_postLayerNorm == true  : LayerNorm is applied before attention and
//                                  before the feed-forward block.
// _pre_or_postLayerNorm == false : LayerNorm is applied after each block.
template <typename T>
void BertTransformerLayer<T>::Forward(unsigned bsz,
                                      const T* input_ptr,
                                      const T* input_mask_ptr,
                                      const T* attn_qkvw_ptr,
                                      const T* attn_qkvb_ptr,
                                      const T* attn_ow_ptr,
                                      const T* attn_ob_ptr,
                                      const T* attn_nw_ptr,
                                      const T* attn_nb_ptr,
                                      const T* inter_w_ptr,
                                      const T* inter_b_ptr,
                                      const T* output_w_ptr,
                                      const T* output_b_ptr,
                                      const T* norm_w_ptr,
                                      const T* norm_b_ptr,
                                      T* out_ptr,
                                      T* inp_norm_ptr,
                                      T* q_tf_ptr,
                                      T* k_tf_ptr,
                                      T* v_tf_ptr,
                                      T* soft_out_ptr,
                                      T* ctx_bufB_ptr,
                                      T* attn_o_inp_ptr,
                                      T* add_res_ptr,
                                      T* ff1_inp_ptr,
                                      T* gelu_inp_ptr,
                                      T* ff2_inp_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) { cudaStreamSynchronize(_stream); }

    // ---- workspace layout ----
    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1;

    if (_normalize_invertible) {
        add_res_ptr = buf_1 + 3 * small_buf_size;
        buf_2 = add_res_ptr;
    }
    if (_gelu_checkpoint) { buf_2 += small_buf_size; }
    if (_attn_dropout_checkpoint) {
        ctx_bufB_ptr =
            (_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size)
                              : (buf_1 + 4 * small_buf_size));
    }

    int bsz_seq = bsz * _seq_length;

    // ---- (optional pre-LayerNorm) + QKV projection ----
    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean()) {
            _layer_norm.ForwardCheckpoint(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        } else {
            _layer_norm.Forward(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        }
        _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
    } else {
        _qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
    }

    launch_bias_add_transform_0213<T>(
        q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);

    int bsz_heads = bsz * _heads;

    // attention scores
    _attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);

    // Softmax + Mask
    _softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);

    // attn prob dropout.
    _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);

    // attention context
    _attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);

    launch_transform4d_0213<T>(
        attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);

    // ---- attention output projection + dropout (with bias and residual) ----
    // Pre-LN writes the projection to scratch; post-LN reuses ff1_inp_ptr.
    T* attn_proj_out = _pre_or_postLayerNorm ? buf_1 : ff1_inp_ptr;
    _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, attn_proj_out, _cublasHandle);

    // attn output dropout.
    _attn_output_dropout.ForwardWithBias(
        bsz_seq, add_res_ptr, attn_proj_out, input_ptr, attn_ob_ptr, _stream);

    // ---- attention LayerNorm (same call for pre- and post-LN) ----
    if (_attn_layer_norm.UseMean()) {
        _attn_layer_norm.ForwardCheckpoint(
            bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    } else {
        _attn_layer_norm.Forward(
            bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    }

    // ---- feed-forward block ----
    // With GELU checkpointing, ff1 writes into ff2_inp_ptr and the GELU output
    // goes to scratch (buf_2); otherwise gelu_inp_ptr / ff2_inp_ptr are used.
    T* ff1_out = _gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr;
    T* gelu_out = _gelu_checkpoint ? buf_2 : ff2_inp_ptr;

    _ff1.Forward(bsz_seq, ff1_inp_ptr, inter_w_ptr, ff1_out, _cublasHandle);
    _gelu.ForwardWithBiasAdd(bsz_seq, ff1_out, inter_b_ptr, gelu_out, _stream);
    _ff2.Forward(bsz_seq, gelu_out, output_w_ptr, out_ptr, _cublasHandle);

    // layer output dropout.
    if (_pre_or_postLayerNorm) {
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
    } else {
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);

        // ---- final LayerNorm (post-LN only) ----
        if (_layer_norm.UseMean()) {
            _layer_norm.ForwardCheckpoint(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        } else {
            _layer_norm.Forward(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        }
    }
}
// Backward pass of one transformer layer for `bsz` sequences.
//
// Inputs are the upstream gradient, the activations saved by Forward and the
// layer parameters; outputs are grad_input_ptr and the parameter gradients.
// Four scratch buffers of bsz * seq_length * hidden_size elements are taken
// from the shared workspace, followed by ff2_buf and (when attention dropout
// is checkpointed) a buffer for the recomputed attention probabilities.
template <typename T>
void BertTransformerLayer<T>::Backward(unsigned bsz,
                                       const T* grad_output_ptr,
                                       const T* input_ptr,
                                       const T* output_ptr,
                                       const T* inp_norm_ptr,
                                       const T* q_tf_ptr,
                                       const T* k_tf_ptr,
                                       const T* v_tf_ptr,
                                       const T* soft_out_ptr,
                                       const T* ctx_bufB_ptr,
                                       const T* attn_o_inp_ptr,
                                       const T* add_res_ptr,
                                       const T* ff1_inp_ptr,
                                       const T* gelu_inp_ptr,
                                       const T* ff2_inp_ptr,
                                       const T* input_mask_ptr,
                                       const T* attn_qkvw_ptr,
                                       const T* attn_ow_ptr,
                                       const T* attn_nw_ptr,
                                       const T* attn_nb_ptr,
                                       const T* inter_w_ptr,
                                       const T* inter_b_ptr,
                                       const T* output_w_ptr,
                                       const T* norm_w_ptr,
                                       const T* norm_b_ptr,
                                       T* grad_input_ptr,
                                       T* grad_attn_qkvw_ptr,
                                       T* grad_attn_qkvb_ptr,
                                       T* grad_attn_ow_ptr,
                                       T* grad_attn_ob_ptr,
                                       T* grad_attn_nw_ptr,
                                       T* grad_attn_nb_ptr,
                                       T* grad_inter_w_ptr,
                                       T* grad_inter_b_ptr,
                                       T* grad_output_w_ptr,
                                       T* grad_output_b_ptr,
                                       T* grad_norm_w_ptr,
                                       T* grad_norm_b_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) { cudaStreamSynchronize(_stream); }

    // ---- workspace layout ----
    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1 + small_buf_size;
    T* buf_3 = buf_2 + small_buf_size;

    T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size)
                                   : buf_3 + small_buf_size);
    T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);

    // Both entries refer to the same stream.
    cudaStream_t streams[2] = {_stream, _stream};

    int bsz_seq = bsz * _seq_length;
    int bsz_heads = bsz * _heads;

    // ---- final LayerNorm (post-LN only) ----
    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean()) {
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 inp_norm_ptr);
        } else {
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 norm_b_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 output_ptr);
        }
    }

    // ---- layer output dropout ----
    // Gradient entering the dropout: the upstream gradient for pre-LN, the
    // LayerNorm gradient (buf_1) for post-LN.
    const T* out_dropout_grad_in = _pre_or_postLayerNorm ? grad_output_ptr : buf_1;
    _layer_output_dropout.Backward(bsz_seq, buf_0, out_dropout_grad_in, _stream);

    const T* layer_dropout_buf =
        _layer_output_dropout.HasDropout() ? buf_0 : out_dropout_grad_in;

    // ---- feed-forward block ----
    // With GELU checkpointing the activation is recomputed into buf_2.
    if (_gelu_checkpoint) {
        _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
    }
    const T* ff2_input = _gelu_checkpoint ? buf_2 : ff2_inp_ptr;
    _ff2.Backward(bsz_seq,
                  layer_dropout_buf,
                  ff2_input,
                  output_w_ptr,
                  grad_output_w_ptr,
                  grad_output_b_ptr,
                  _cublasHandle,
                  _stream,
                  ff2_buf);

    const T* gelu_input = _gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr;
    _gelu.Backward(bsz_seq, ff2_buf, gelu_input, inter_b_ptr, _stream);

    _ff1.Backward(bsz_seq,
                  ff2_buf,
                  ff1_inp_ptr,
                  inter_w_ptr,
                  grad_inter_w_ptr,
                  grad_inter_b_ptr,
                  _cublasHandle,
                  _stream,
                  buf_3);

    if (!_pre_or_postLayerNorm) {
        launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);
    }

    // ---- attention LayerNorm ----
    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean()) {
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              add_res_ptr);
        } else {
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              attn_nb_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              ff1_inp_ptr);
        }
    } else {
        if (_attn_layer_norm.UseMean()) {
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      add_res_ptr);
        } else {
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      attn_nb_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      ff1_inp_ptr);
        }
    }

    // ---- attention output dropout + projection ----
    _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);

    T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;

    _attn_out_linear.Backward(bsz_seq,
                              attn_output_dropout_buf,
                              attn_o_inp_ptr,
                              attn_ow_ptr,
                              grad_attn_ow_ptr,
                              grad_attn_ob_ptr,
                              _cublasHandle,
                              _stream,
                              buf_1);

    launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);

    // ---- attention context ----
    if (_attn_prob_dropout.HasDropout()) {
        // Recompute the dropped-out probabilities when they were not saved.
        if (_attn_dropout_checkpoint) {
            _attn_prob_dropout.Forward(
                bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);
        }

        _attn_context.Backward(bsz_heads,
                               buf_2,
                               v_tf_ptr,
                               (_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr),
                               _cublasHandle,
                               buf_3,
                               ff2_buf);
    } else {
        _attn_context.Backward(
            bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf);
    }

    // ---- attention probabilities: dropout, softmax, scores ----
    _attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);

    _softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);

    _attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);

    launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);

    // ---- QKV projection ----
    // Its forward input was the normalized input for pre-LN, the raw input otherwise.
    const T* qkv_forward_input = _pre_or_postLayerNorm ? inp_norm_ptr : input_ptr;
    _qkv_linear.Backward(bsz_seq,
                         ff2_buf,
                         qkv_forward_input,
                         attn_qkvw_ptr,
                         grad_attn_qkvw_ptr,
                         grad_attn_qkvb_ptr,
                         _cublasHandle,
                         _stream,
                         buf_2);

    // ---- input gradient ----
    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean()) {
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         input_ptr);
        } else {
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         norm_b_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         inp_norm_ptr);
        }
    } else {
        launch_fused_add2<T>(
            grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
    }
}
// Forwards the training flag to the three dropout operators of this layer.
template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
    // Dropout will be skipped when not in training mode.
    _attn_prob_dropout.SetTrainingMode(training);
    _attn_output_dropout.SetTrainingMode(training);
    _layer_output_dropout.SetTrainingMode(training);
}
// Hands caller-provided buffers to the sub-operators: one mask buffer for each
// of the three dropouts, and a variance and mean buffer for each of the two
// layer norms.
template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                                     uint8_t* attn_output_dropout_mask_ptr,
                                                     uint8_t* layer_output_dropout_mask_ptr,
                                                     T* attn_layer_norm_var,
                                                     T* attn_layer_norm_mean,
                                                     T* layer_norm_var,
                                                     T* layer_norm_mean)
{
    // Dropout masks.
    _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
    _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
    _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);

    // Layer-norm statistics.
    _attn_layer_norm.SetVar(attn_layer_norm_var);
    _attn_layer_norm.SetMean(attn_layer_norm_mean);
    _layer_norm.SetVar(layer_norm_var);
    _layer_norm.SetMean(layer_norm_mean);
}
// Updates the stored sequence length and re-configures the operators that are
// given it here: softmax, attention-probability dropout and both attention GEMMs.
template <typename T>
void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
{
    _seq_length = seq_len;

    _softmax.SetSeqLength(_seq_length);
    _attn_prob_dropout.SetDimension(_seq_length);

    _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
    _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
}
/// Creates a BertTransformerLayer<T> and registers it under `layer_id`.
///
/// The seed is installed on the global Context, the optional GEMM algorithm
/// test is run via Context::TestGemmFP16, and the new layer is built with the
/// initial sequence length `init_seq_length`. The layer is then stored in
/// `s_transformer_layers[layer_id]`, replacing any entry already present under
/// that id.
///
/// `hidden_dim / num_heads` is evaluated here, so `num_heads` must be non-zero.
///
/// @return always 0.
template <typename T>
int create_transformer_layer(unsigned layer_id,
                             unsigned batch_size,
                             unsigned hidden_dim,
                             unsigned num_heads,
                             unsigned intermediate_size,
                             float attn_dropout_ratio,
                             float hidden_dropout_ratio,
                             float layer_norm_eps,
                             int seed,
                             bool pre_or_postLayerNorm,
                             bool test_gemm,
                             bool attn_dropout_checkpoint,
                             bool normalize_invertible,
                             bool gelu_checkpoint,
                             bool stochastic_mode)
{
    Context::Instance().SetSeed(seed);
    Context::Instance().TestGemmFP16(
        test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

    auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
                                                           batch_size,
                                                           hidden_dim,
                                                           num_heads,
                                                           intermediate_size,
                                                           init_seq_length,
                                                           attn_dropout_ratio,
                                                           hidden_dropout_ratio,
                                                           layer_norm_eps,
                                                           pre_or_postLayerNorm,
                                                           Context::Instance().GetGemmAlgos(),
                                                           attn_dropout_checkpoint,
                                                           normalize_invertible,
                                                           gelu_checkpoint,
                                                           stochastic_mode);

    // The registry is type-erased (shared_ptr<void>); the forward/backward
    // entry points cast the entry back to BertTransformerLayer<T>.
    s_transformer_layers[layer_id] = layer;

    std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";

    // Message previously read "date type"; corrected to "data type".
    std::cout << "layer #" << layer_id << " is created with data type [" << dtype << "]."
              << std::endl;

    return 0;
}
/// Python-facing forward pass of a previously created transformer layer.
///
/// Validates the input tensors with CHECK_INPUT, looks up the layer registered
/// under `layer_id`, allocates the output, the workspace and every
/// intermediate buffer, and runs BertTransformerLayer<T>::Forward on raw
/// pointers.
///
/// Buffer sharing decided by the flags (a shared buffer is the same tensor
/// handle, not a copy):
///  - `inp_norm` is `output` unless `prelayernorm || !normalize_invertible`;
///  - `add_res` is `inp_norm` when `normalize_invertible`;
///  - `gelu_inp` is `ff2_inp` when `gelu_checkpoint`;
///  - `ctx_bufB` is `soft_out` when `attn_dropout_checkpoint`.
///
/// @return {output, inp_norm, qkv_tf, soft_out, ctx_bufB, attn_o_inp, add_res,
///          ff1_inp, gelu_inp, ff2_inp, attn_prob_dropout_mask,
///          attn_output_dropout_mask, layer_output_dropout_mask,
///          attn_layer_norm_var, attn_layer_norm_mean, layer_norm_var,
///          layer_norm_mean}, in that order.
/// @throws std::out_of_range if no layer was registered under `layer_id`.
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
                                                  const torch::Tensor& input,
                                                  const torch::Tensor& input_mask,
                                                  const torch::Tensor& attn_qkvw,
                                                  const torch::Tensor& attn_qkvb,
                                                  const torch::Tensor& attn_ow,
                                                  const torch::Tensor& attn_ob,
                                                  const torch::Tensor& attn_nw,
                                                  const torch::Tensor& attn_nb,
                                                  const torch::Tensor& inter_w,
                                                  const torch::Tensor& inter_b,
                                                  const torch::Tensor& output_w,
                                                  const torch::Tensor& output_b,
                                                  const torch::Tensor& norm_w,
                                                  const torch::Tensor& norm_b,
                                                  bool training_mode,
                                                  bool prelayernorm,
                                                  bool attn_dropout_checkpoint,
                                                  bool normalize_invertible,
                                                  bool gelu_checkpoint)
{
    // Raw data pointers are taken from every tensor below, so each one is
    // validated first.
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = input.size(0);

    const T* input_ptr = static_cast<const T*>(input.data_ptr());
    const T* input_mask_ptr = static_cast<const T*>(input_mask.data_ptr());
    const T* attn_qkvw_ptr = static_cast<const T*>(attn_qkvw.data_ptr());
    const T* attn_qkvb_ptr = static_cast<const T*>(attn_qkvb.data_ptr());
    const T* attn_ow_ptr = static_cast<const T*>(attn_ow.data_ptr());
    const T* attn_ob_ptr = static_cast<const T*>(attn_ob.data_ptr());
    const T* attn_nw_ptr = static_cast<const T*>(attn_nw.data_ptr());
    const T* attn_nb_ptr = static_cast<const T*>(attn_nb.data_ptr());
    const T* inter_w_ptr = static_cast<const T*>(inter_w.data_ptr());
    const T* inter_b_ptr = static_cast<const T*>(inter_b.data_ptr());
    const T* output_w_ptr = static_cast<const T*>(output_w.data_ptr());
    const T* output_b_ptr = static_cast<const T*>(output_b.data_ptr());
    const T* norm_w_ptr = static_cast<const T*>(norm_w.data_ptr());
    const T* norm_b_ptr = static_cast<const T*>(norm_b.data_ptr());

    auto output = torch::empty_like(input);
    T* out_ptr = static_cast<T*>(output.data_ptr());

    // Options for the floating-point scratch tensors (same dtype as `input`).
    auto options = torch::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    // Options for the dropout masks. They are allocated as kInt8 and handed to
    // the layer as uint8_t* further down.
    auto uint8_options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .layout(torch::kStrided)
                             .device(torch::kCUDA)
                             .requires_grad(false);

    // at() instead of operator[]: an unknown layer_id now raises
    // std::out_of_range instead of inserting a null entry that would be
    // dereferenced on the next line.
    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers.at(layer_id));

    // Re-configure the layer when the incoming sequence length differs from
    // the one it currently holds.
    unsigned seq_len = layer->GetSeqLength();
    if (input.size(1) != seq_len) {
        seq_len = input.size(1);
        layer->SetSeqLength(seq_len);
    }

    // Note: IsTrainingMode() is read here, before SetTrainingMode(training_mode)
    // is called further down.
    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace(static_cast<T*>(workspace.data_ptr()));

    auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
    auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
    auto attn_o_inp = torch::empty_like(input);
    // Q, K and V are stored back to back in one tensor.
    auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);

    auto attn_prob_dropout_mask =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
    auto attn_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
    auto layer_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);

    auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);

    T* inp_norm_ptr = static_cast<T*>(inp_norm.data_ptr());
    T* add_res_ptr = static_cast<T*>(add_res.data_ptr());
    // K and V start one and two (bsz * seq_len * output_w.size(0)) strides
    // after Q inside qkv_tf.
    T* q_tf_ptr = static_cast<T*>(qkv_tf.data_ptr());
    T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0));
    T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0));
    T* attn_o_inp_ptr = static_cast<T*>(attn_o_inp.data_ptr());

    torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
    torch::Tensor gelu_inp =
        (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
    auto ff1_inp = torch::empty_like(input);
    T* ff2_inp_ptr = static_cast<T*>(ff2_inp.data_ptr());
    T* gelu_inp_ptr = static_cast<T*>(gelu_inp.data_ptr());
    T* ff1_inp_ptr = static_cast<T*>(ff1_inp.data_ptr());

    torch::Tensor soft_out =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
    torch::Tensor ctx_bufB =
        (attn_dropout_checkpoint
             ? soft_out
             : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
    T* soft_out_ptr = static_cast<T*>(soft_out.data_ptr());
    T* ctx_bufB_ptr = static_cast<T*>(ctx_bufB.data_ptr());

    layer->SetTrainingMode(training_mode);
    layer->SetIntermediateBuffers(static_cast<uint8_t*>(attn_prob_dropout_mask.data_ptr()),
                                  static_cast<uint8_t*>(attn_output_dropout_mask.data_ptr()),
                                  static_cast<uint8_t*>(layer_output_dropout_mask.data_ptr()),
                                  static_cast<T*>(attn_layer_norm_var.data_ptr()),
                                  static_cast<T*>(attn_layer_norm_mean.data_ptr()),
                                  static_cast<T*>(layer_norm_var.data_ptr()),
                                  static_cast<T*>(layer_norm_mean.data_ptr()));

    layer->Forward(bsz,
                   input_ptr,
                   input_mask_ptr,
                   attn_qkvw_ptr,
                   attn_qkvb_ptr,
                   attn_ow_ptr,
                   attn_ob_ptr,
                   attn_nw_ptr,
                   attn_nb_ptr,
                   inter_w_ptr,
                   inter_b_ptr,
                   output_w_ptr,
                   output_b_ptr,
                   norm_w_ptr,
                   norm_b_ptr,
                   out_ptr,
                   inp_norm_ptr,
                   q_tf_ptr,
                   k_tf_ptr,
                   v_tf_ptr,
                   soft_out_ptr,
                   ctx_bufB_ptr,
                   attn_o_inp_ptr,
                   add_res_ptr,
                   ff1_inp_ptr,
                   gelu_inp_ptr,
                   ff2_inp_ptr);

    return {output,
            inp_norm,
            qkv_tf,
            soft_out,
            ctx_bufB,
            attn_o_inp,
            add_res,
            ff1_inp,
            gelu_inp,
            ff2_inp,
            attn_prob_dropout_mask,
            attn_output_dropout_mask,
            layer_output_dropout_mask,
            attn_layer_norm_var,
            attn_layer_norm_mean,
            layer_norm_var,
            layer_norm_mean};
}
/// Python-facing backward pass of a previously created transformer layer.
///
/// Takes the incoming gradient, the tensors produced by the forward pass and
/// the layer parameters, allocates one gradient tensor per parameter (plus
/// the input gradient) and runs BertTransformerLayer<T>::Backward on raw
/// pointers.
///
/// `grad_output` is made contiguous before use. The dropout-mask and
/// layer-norm statistic tensors are passed to the layer without a CHECK_INPUT
/// call, exactly as before.
///
/// @return {grad_input, grad_attn_qkvw, grad_attn_qkvb, grad_attn_ow,
///          grad_attn_ob, grad_attn_nw, grad_attn_nb, grad_inter_w,
///          grad_inter_b, grad_output_w, grad_output_b, grad_norm_w,
///          grad_norm_b}, in that order.
/// @throws std::out_of_range if no layer was registered under `layer_id`.
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
                                                   const torch::Tensor& grad_output,
                                                   const torch::Tensor& output,
                                                   const torch::Tensor& inp_norm,
                                                   const torch::Tensor& qkv_tf,
                                                   const torch::Tensor& soft_out,
                                                   const torch::Tensor& ctx_bufB,
                                                   const torch::Tensor& attn_o_inp,
                                                   const torch::Tensor& add_res,
                                                   const torch::Tensor& ff1_inp,
                                                   const torch::Tensor& gelu_inp,
                                                   const torch::Tensor& ff2_inp,
                                                   const torch::Tensor& attn_prob_dropout_mask,
                                                   const torch::Tensor& attn_output_dropout_mask,
                                                   const torch::Tensor& layer_output_dropout_mask,
                                                   const torch::Tensor& attn_layer_norm_var,
                                                   const torch::Tensor& attn_layer_norm_mean,
                                                   const torch::Tensor& layer_norm_var,
                                                   const torch::Tensor& layer_norm_mean,
                                                   const torch::Tensor& input,
                                                   const torch::Tensor& input_mask,
                                                   const torch::Tensor& attn_qkvw,
                                                   const torch::Tensor& attn_qkvb,
                                                   const torch::Tensor& attn_ow,
                                                   const torch::Tensor& attn_ob,
                                                   const torch::Tensor& attn_nw,
                                                   const torch::Tensor& attn_nb,
                                                   const torch::Tensor& inter_w,
                                                   const torch::Tensor& inter_b,
                                                   const torch::Tensor& output_w,
                                                   const torch::Tensor& output_b,
                                                   const torch::Tensor& norm_w,
                                                   const torch::Tensor& norm_b)
{
    auto g_output = grad_output.contiguous();

    CHECK_INPUT(g_output);
    CHECK_INPUT(output);
    CHECK_INPUT(inp_norm);
    CHECK_INPUT(qkv_tf);
    CHECK_INPUT(add_res);
    CHECK_INPUT(soft_out);
    CHECK_INPUT(ctx_bufB);
    CHECK_INPUT(attn_o_inp);
    CHECK_INPUT(ff1_inp);
    CHECK_INPUT(gelu_inp);
    CHECK_INPUT(ff2_inp);
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = g_output.size(0);

    // at() instead of operator[]: an unknown layer_id now raises
    // std::out_of_range instead of inserting a null entry that would be
    // dereferenced on the next line.
    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers.at(layer_id));

    // Re-configure the layer when the gradient's sequence length differs from
    // the one it currently holds.
    unsigned seq_len = layer->GetSeqLength();
    if (g_output.size(1) != seq_len) {
        seq_len = g_output.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto options = torch::TensorOptions()
                       .dtype(g_output.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace(static_cast<T*>(workspace.data_ptr()));

    // One gradient tensor per parameter, shaped like the parameter itself.
    auto grad_input = torch::empty_like(input);
    auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
    auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
    auto grad_attn_ow = torch::empty_like(attn_ow);
    auto grad_attn_ob = torch::empty_like(attn_ob);
    auto grad_attn_nw = torch::empty_like(attn_nw);
    auto grad_attn_nb = torch::empty_like(attn_nb);
    auto grad_inter_w = torch::empty_like(inter_w);
    auto grad_inter_b = torch::empty_like(inter_b);
    auto grad_output_w = torch::empty_like(output_w);
    auto grad_output_b = torch::empty_like(output_b);
    auto grad_norm_w = torch::empty_like(norm_w);
    auto grad_norm_b = torch::empty_like(norm_b);

    // inputs.
    const T* grad_output_ptr = static_cast<const T*>(g_output.data_ptr());
    const T* input_ptr = static_cast<const T*>(input.data_ptr());
    const T* output_ptr = static_cast<const T*>(output.data_ptr());
    const T* inp_norm_ptr = static_cast<const T*>(inp_norm.data_ptr());
    const T* q_tf_ptr = static_cast<const T*>(qkv_tf.data_ptr());
    const T* add_res_ptr = static_cast<const T*>(add_res.data_ptr());
    // K and V follow Q inside qkv_tf, each one stride further.
    const T* k_tf_ptr = q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));
    const T* v_tf_ptr = k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));
    const T* ff1_inp_ptr = static_cast<const T*>(ff1_inp.data_ptr());
    const T* gelu_inp_ptr = static_cast<const T*>(gelu_inp.data_ptr());
    const T* ff2_inp_ptr = static_cast<const T*>(ff2_inp.data_ptr());
    const T* ctx_bufB_ptr = static_cast<const T*>(ctx_bufB.data_ptr());
    const T* soft_out_ptr = static_cast<const T*>(soft_out.data_ptr());
    const T* attn_o_inp_ptr = static_cast<const T*>(attn_o_inp.data_ptr());
    const T* input_mask_ptr = static_cast<const T*>(input_mask.data_ptr());
    const T* attn_qkvw_ptr = static_cast<const T*>(attn_qkvw.data_ptr());
    const T* attn_ow_ptr = static_cast<const T*>(attn_ow.data_ptr());
    const T* attn_nw_ptr = static_cast<const T*>(attn_nw.data_ptr());
    const T* attn_nb_ptr = static_cast<const T*>(attn_nb.data_ptr());
    const T* inter_w_ptr = static_cast<const T*>(inter_w.data_ptr());
    const T* inter_b_ptr = static_cast<const T*>(inter_b.data_ptr());
    const T* output_w_ptr = static_cast<const T*>(output_w.data_ptr());
    const T* norm_w_ptr = static_cast<const T*>(norm_w.data_ptr());
    const T* norm_b_ptr = static_cast<const T*>(norm_b.data_ptr());

    // outputs.
    T* grad_input_ptr = static_cast<T*>(grad_input.data_ptr());
    T* grad_attn_qkvw_ptr = static_cast<T*>(grad_attn_qkvw.data_ptr());
    T* grad_attn_qkvb_ptr = static_cast<T*>(grad_attn_qkvb.data_ptr());
    T* grad_attn_ow_ptr = static_cast<T*>(grad_attn_ow.data_ptr());
    T* grad_attn_ob_ptr = static_cast<T*>(grad_attn_ob.data_ptr());
    T* grad_attn_nw_ptr = static_cast<T*>(grad_attn_nw.data_ptr());
    T* grad_attn_nb_ptr = static_cast<T*>(grad_attn_nb.data_ptr());
    T* grad_inter_w_ptr = static_cast<T*>(grad_inter_w.data_ptr());
    T* grad_inter_b_ptr = static_cast<T*>(grad_inter_b.data_ptr());
    T* grad_output_w_ptr = static_cast<T*>(grad_output_w.data_ptr());
    T* grad_output_b_ptr = static_cast<T*>(grad_output_b.data_ptr());
    T* grad_norm_w_ptr = static_cast<T*>(grad_norm_w.data_ptr());
    T* grad_norm_b_ptr = static_cast<T*>(grad_norm_b.data_ptr());

    // Re-attach the masks and statistics saved by the forward pass.
    layer->SetIntermediateBuffers(static_cast<uint8_t*>(attn_prob_dropout_mask.data_ptr()),
                                  static_cast<uint8_t*>(attn_output_dropout_mask.data_ptr()),
                                  static_cast<uint8_t*>(layer_output_dropout_mask.data_ptr()),
                                  static_cast<T*>(attn_layer_norm_var.data_ptr()),
                                  static_cast<T*>(attn_layer_norm_mean.data_ptr()),
                                  static_cast<T*>(layer_norm_var.data_ptr()),
                                  static_cast<T*>(layer_norm_mean.data_ptr()));

    layer->Backward(bsz,
                    grad_output_ptr,
                    input_ptr,
                    output_ptr,
                    inp_norm_ptr,
                    q_tf_ptr,
                    k_tf_ptr,
                    v_tf_ptr,
                    soft_out_ptr,
                    ctx_bufB_ptr,
                    attn_o_inp_ptr,
                    add_res_ptr,
                    ff1_inp_ptr,
                    gelu_inp_ptr,
                    ff2_inp_ptr,
                    input_mask_ptr,
                    attn_qkvw_ptr,
                    attn_ow_ptr,
                    attn_nw_ptr,
                    attn_nb_ptr,
                    inter_w_ptr,
                    inter_b_ptr,
                    output_w_ptr,
                    norm_w_ptr,
                    norm_b_ptr,
                    grad_input_ptr,
                    grad_attn_qkvw_ptr,
                    grad_attn_qkvb_ptr,
                    grad_attn_ow_ptr,
                    grad_attn_ob_ptr,
                    grad_attn_nw_ptr,
                    grad_attn_nb_ptr,
                    grad_inter_w_ptr,
                    grad_inter_b_ptr,
                    grad_output_w_ptr,
                    grad_output_b_ptr,
                    grad_norm_w_ptr,
                    grad_norm_b_ptr);

    return {grad_input,
            grad_attn_qkvw,
            grad_attn_qkvb,
            grad_attn_ow,
            grad_attn_ob,
            grad_attn_nw,
            grad_attn_nb,
            grad_inter_w,
            grad_inter_b,
            grad_output_w,
            grad_output_b,
            grad_norm_w,
            grad_norm_b};
}
// Python bindings: each entry point is exported twice, once instantiated for
// float (fp32) and once for __half (fp16).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // Forward pass.
    m.def("forward_fp32",
          &ds_transformer_forward<float>,
          "DeepSpeed Transformer forward with fp32 (CUDA)");
    m.def("forward_fp16",
          &ds_transformer_forward<__half>,
          "DeepSpeed Transformer forward with fp16 (CUDA)");

    // Backward pass.
    m.def("backward_fp32",
          &ds_transformer_backward<float>,
          "DeepSpeed Transformer backward with fp32 (CUDA)");
    m.def("backward_fp16",
          &ds_transformer_backward<__half>,
          "DeepSpeed Transformer backward with fp16 (CUDA)");

    // Layer construction / registration.
    m.def("create_transformer_layer_fp32",
          &create_transformer_layer<float>,
          "Create DeepSpeed Transformer Transformer Layer with fp32 (CUDA)");
    m.def("create_transformer_layer_fp16",
          &create_transformer_layer<__half>,
          "Create DeepSpeed Transformer Transformer Layer with fp16 (CUDA)");
}
deepspeed/ops/csrc/transformer/ds_transformer_hip.cpp
deleted
100644 → 0
View file @
1b2721ad
// !!! This is a file automatically generated by hipify!!!
#include <torch/extension.h>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer_hip.h"
#include "context_hip.h"
#include "cublas_wrappers_hip.h"
#include "custom_hip_layers.h"
#include "ds_transformer_hip.h"
// Registry of the layers created through create_transformer_layer, keyed by
// layer id. Entries are type-erased; users cast them back to
// BertTransformerLayer<T>.
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

// Sequence length a layer is constructed with; SetSeqLength changes it later
// when an input of a different length arrives.
const int init_seq_length = 128;
// C++ interface
/// Number of workspace elements (not bytes) needed for one forward/backward call.
///
/// Always reserves four `batch * seq * hidden` buffers. In training mode it
/// adds two more of those plus the larger of one `batch * seq * intermediate`
/// buffer and two `batch * heads * seq * seq` buffers; with GELU
/// checkpointing it adds a further two `batch * seq * intermediate` buffers.
///
/// The sum is formed in size_t and returned as `unsigned`, so a total that
/// does not fit in `unsigned` wraps modulo 2^32 (unchanged behaviour).
template <typename T>
unsigned get_workspace_size(unsigned maxBatchSize,
                            unsigned seq_len,
                            unsigned hidden_size,
                            unsigned intermediate_size,
                            unsigned heads,
                            bool training,
                            bool gelu_checkpoint)
{
    const size_t hidden_elems = size_t(maxBatchSize) * seq_len * hidden_size;
    const size_t inter_elems = size_t(maxBatchSize) * seq_len * intermediate_size;
    const size_t score_elems = size_t(maxBatchSize) * heads * seq_len * seq_len;

    size_t total = 4 * hidden_elems;
    if (training) {
        total += 2 * hidden_elems;
        total += (std::max)(inter_elems, 2 * score_elems);
        if (gelu_checkpoint) { total += 2 * inter_elems; }
    }

    // Element count only; the caller does not multiply by sizeof(T).
    return static_cast<unsigned>(total);
}
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
//
// Tensor validation helpers. The argument is parenthesized so any expression
// can be passed, and CHECK_INPUT is wrapped in do/while(0) so it behaves as a
// single statement (safe inside an unbraced if/else). Tensor::is_cuda() is
// used instead of the deprecated Tensor::type().is_cuda().
#define CHECK_CUDA(x) AT_ASSERTM((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM((x).is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)       \
    do {                     \
        CHECK_CUDA(x);       \
        CHECK_CONTIGUOUS(x); \
    } while (0)
/// Builds one transformer layer and all of its sub-operators.
///
/// Scalar settings are copied into members first; the stream and BLAS handle
/// come from the global Context; every sub-operator is then constructed from
/// its own Config. `gemm_algos` supplies the GEMM algorithm triples:
/// index 0 for both attention linears, 1 for FF1, 2 for FF2, 3 for the
/// attention-score batched GEMM and 4 for the attention-context batched GEMM.
///
/// `_training` starts out as true. `hidden_size` is expected to be a multiple
/// of `num_heads` (checked with assert in the body).
template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
                                              unsigned batch_size,
                                              unsigned hidden_size,
                                              unsigned num_heads,
                                              unsigned intermediate_size,
                                              unsigned seq_length,
                                              float attn_prob_dropout_ratio,
                                              float hidden_output_dropout_ratio,
                                              float layer_norm_eps,
                                              bool pre_or_postLayerNorm,
                                              const std::vector<std::array<int, 3>>& gemm_algos,
                                              bool attn_dropout_checkpoint,
                                              bool normalize_invertible,
                                              bool gelu_checkpoint,
                                              bool stochastic_mode)
    // --- plain configuration -------------------------------------------------
    : _layer_id(layer_id),
      _batch_size(batch_size),
      _hidden_size(hidden_size),
      _heads(num_heads),
      _intermediate_size(intermediate_size),
      _seq_length(seq_length),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _attn_dropout_checkpoint(attn_dropout_checkpoint),
      _normalize_invertible(normalize_invertible),
      _gelu_checkpoint(gelu_checkpoint),
      _stochastic_mode(stochastic_mode),
      // --- execution resources shared through the Context --------------------
      _stream(Context::Instance().GetCurrentStream()),
      _cublasHandle(Context::Instance().GetCublasHandle()),
      // --- attention linears and layer norms ---------------------------------
      _qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                  3 * hidden_size,
                                                  hidden_size,
                                                  gemm_algos[0])),
      _attn_out_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                       hidden_size,
                                                       hidden_size,
                                                       gemm_algos[0])),
      _attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                           seq_length,
                                                           hidden_size,
                                                           layer_norm_eps,
                                                           true,
                                                           !normalize_invertible)),
      _layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                      seq_length,
                                                      hidden_size,
                                                      layer_norm_eps,
                                                      true,
                                                      !normalize_invertible)),
      // --- feed-forward block -------------------------------------------------
      _ff1(typename FeedForward<T>::Config(batch_size * seq_length,
                                           _intermediate_size,
                                           hidden_size,
                                           gemm_algos[1])),
      _ff2(typename FeedForward<T>::Config(batch_size * seq_length,
                                           hidden_size,
                                           _intermediate_size,
                                           gemm_algos[2])),
      _softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
      _gelu(typename Gelu<T>::Config(_intermediate_size)),
      // --- dropouts -----------------------------------------------------------
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
      _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      // --- batched attention GEMMs --------------------------------------------
      // The score scale is 1 / sqrt(head_dim), evaluated in double and then
      // converted to T. (An earlier form divided T(1.0) by T(sqrt(...)); it
      // was replaced during the port.)
      _attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                        _seq_length,
                                                        _seq_length,
                                                        _hidden_size / _heads,
                                                        T(1.0 / sqrt(_hidden_size / _heads)),
                                                        T(0.0),
                                                        rocblas_operation_transpose,
                                                        rocblas_operation_none,
                                                        gemm_algos[3])),
      _attn_context(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                         _hidden_size / _heads,
                                                         _seq_length,
                                                         _seq_length,
                                                         T(1.0),
                                                         T(0.0),
                                                         rocblas_operation_none,
                                                         rocblas_operation_none,
                                                         gemm_algos[4]))
{
    // Heads must evenly split the hidden dimension.
    assert(_hidden_size % _heads == 0);

    Initialize();
}
/// Nothing to release explicitly: the destructor body is empty, so members
/// clean up through their own destructors.
template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer() = default;
/// One-time setup run from the constructor.
///
/// On non-HIP builds, half-precision layers switch the BLAS handle to
/// tensor-op math. On HIP builds the function does nothing.
///
/// The guard tests both `__HIP_PLATFORM_HCC__` (the older macro) and
/// `__HIP_PLATFORM_AMD__` (its replacement), so the CUDA-only call is excluded
/// on ROCm toolchains that define only the newer macro.
template <typename T>
void BertTransformerLayer<T>::Initialize()
{
#if !defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_AMD__)
    if (std::is_same<T, __half>::value)
        rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
}
/// Runs the forward pass of the layer on raw device pointers.
///
/// Launch order: (pre-LN only) input layer norm, QKV linear, bias-add +
/// 0213 transform, attention scores, softmax + mask, attention-probability
/// dropout, attention context, 4d 0213 transform, attention output linear,
/// attention output dropout with bias, attention layer norm, FF1, GELU with
/// bias, FF2, layer output dropout with bias, (post-LN only) final layer norm.
///
/// Scratch buffers are carved out of the Context workspace in units of
/// `bsz * _seq_length * _hidden_size` elements. When `_normalize_invertible`
/// is set, `add_res_ptr` is redirected into the workspace; when
/// `_attn_dropout_checkpoint` is set, `ctx_bufB_ptr` is redirected as well.
template <typename T>
void BertTransformerLayer<T>::Forward(unsigned bsz,
                                      const T* input_ptr,
                                      const T* input_mask_ptr,
                                      const T* attn_qkvw_ptr,
                                      const T* attn_qkvb_ptr,
                                      const T* attn_ow_ptr,
                                      const T* attn_ob_ptr,
                                      const T* attn_nw_ptr,
                                      const T* attn_nb_ptr,
                                      const T* inter_w_ptr,
                                      const T* inter_b_ptr,
                                      const T* output_w_ptr,
                                      const T* output_b_ptr,
                                      const T* norm_w_ptr,
                                      const T* norm_b_ptr,
                                      T* out_ptr,
                                      T* inp_norm_ptr,
                                      T* q_tf_ptr,
                                      T* k_tf_ptr,
                                      T* v_tf_ptr,
                                      T* soft_out_ptr,
                                      T* ctx_bufB_ptr,
                                      T* attn_o_inp_ptr,
                                      T* add_res_ptr,
                                      T* ff1_inp_ptr,
                                      T* gelu_inp_ptr,
                                      T* ff2_inp_ptr)
{
    rocblas_set_stream(_cublasHandle, _stream);

    if (!_stochastic_mode) { hipStreamSynchronize(_stream); }

    // ---- workspace layout ----------------------------------------------------
    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1;

    if (_normalize_invertible) {
        add_res_ptr = buf_1 + 3 * small_buf_size;
        buf_2 = add_res_ptr;
    }
    if (_gelu_checkpoint) { buf_2 += small_buf_size; }
    if (_attn_dropout_checkpoint) {
        if (_gelu_checkpoint) {
            ctx_bufB_ptr = buf_2 + (_intermediate_size / _hidden_size) * small_buf_size;
        } else {
            ctx_bufB_ptr = buf_1 + 4 * small_buf_size;
        }
    }

    int bsz_seq = bsz * _seq_length;

    // ---- (pre-LN) input layer norm, then QKV projection ------------------------
    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean()) {
            _layer_norm.ForwardCheckpoint(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        } else {
            _layer_norm.Forward(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        }
        _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
    } else {
        _qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
    }

    launch_bias_add_transform_0213<T>(
        q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);

    int bsz_heads = bsz * _heads;

    // ---- attention -------------------------------------------------------------
    // attention scores
    _attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);

    // Softmax + Mask
    _softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);

    // attn prob dropout.
    _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);

    // attention context
    _attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);

    launch_transform4d_0213<T>(
        attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);

    // The attention output projection lands in buf_1 for pre-LN and in
    // ff1_inp_ptr for post-LN; the dropout that follows reads from there.
    T* attn_proj_out = _pre_or_postLayerNorm ? buf_1 : ff1_inp_ptr;
    _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, attn_proj_out, _cublasHandle);

    // attn output dropout.
    _attn_output_dropout.ForwardWithBias(
        bsz_seq, add_res_ptr, attn_proj_out, input_ptr, attn_ob_ptr, _stream);

    // Attention layer norm: the same call is made for pre-LN and post-LN.
    if (_attn_layer_norm.UseMean()) {
        _attn_layer_norm.ForwardCheckpoint(
            bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    } else {
        _attn_layer_norm.Forward(
            bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    }

    // ---- feed-forward block ----------------------------------------------------
    // With GELU checkpointing the pre-activation goes to ff2_inp_ptr and the
    // activation to workspace (buf_2); otherwise gelu_inp_ptr -> ff2_inp_ptr.
    T* gelu_in = _gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr;
    T* gelu_out = _gelu_checkpoint ? buf_2 : ff2_inp_ptr;

    _ff1.Forward(bsz_seq, ff1_inp_ptr, inter_w_ptr, gelu_in, _cublasHandle);

    _gelu.ForwardWithBiasAdd(bsz_seq, gelu_in, inter_b_ptr, gelu_out, _stream);

    _ff2.Forward(bsz_seq, gelu_out, output_w_ptr, out_ptr, _cublasHandle);

    // ---- layer output dropout (+ final layer norm for post-LN) -----------------
    if (_pre_or_postLayerNorm) {
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
    } else {
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);

        if (_layer_norm.UseMean()) {
            _layer_norm.ForwardCheckpoint(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        } else {
            _layer_norm.Forward(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        }
    }
}
/// Runs the backward pass of the layer on raw device pointers.
///
/// Walks the forward graph in reverse: (post-LN) final layer norm, layer
/// output dropout, FF2, GELU, FF1, attention layer norm, attention output
/// dropout, attention output linear, 0213 transform, attention context,
/// attention-probability dropout, softmax, attention scores, 4d 0213
/// transform, QKV linear, and finally the input layer norm (pre-LN) or a
/// fused add (post-LN) that produces `grad_input_ptr`.
///
/// With `_gelu_checkpoint` the GELU output is recomputed into workspace
/// before FF2's backward; with `_attn_dropout_checkpoint` (and dropout
/// active) the dropped attention probabilities are recomputed likewise.
template <typename T>
void BertTransformerLayer<T>::Backward(unsigned bsz,
                                       const T* grad_output_ptr,
                                       const T* input_ptr,
                                       const T* output_ptr,
                                       const T* inp_norm_ptr,
                                       const T* q_tf_ptr,
                                       const T* k_tf_ptr,
                                       const T* v_tf_ptr,
                                       const T* soft_out_ptr,
                                       const T* ctx_bufB_ptr,
                                       const T* attn_o_inp_ptr,
                                       const T* add_res_ptr,
                                       const T* ff1_inp_ptr,
                                       const T* gelu_inp_ptr,
                                       const T* ff2_inp_ptr,
                                       const T* input_mask_ptr,
                                       const T* attn_qkvw_ptr,
                                       const T* attn_ow_ptr,
                                       const T* attn_nw_ptr,
                                       const T* attn_nb_ptr,
                                       const T* inter_w_ptr,
                                       const T* inter_b_ptr,
                                       const T* output_w_ptr,
                                       const T* norm_w_ptr,
                                       const T* norm_b_ptr,
                                       T* grad_input_ptr,
                                       T* grad_attn_qkvw_ptr,
                                       T* grad_attn_qkvb_ptr,
                                       T* grad_attn_ow_ptr,
                                       T* grad_attn_ob_ptr,
                                       T* grad_attn_nw_ptr,
                                       T* grad_attn_nb_ptr,
                                       T* grad_inter_w_ptr,
                                       T* grad_inter_b_ptr,
                                       T* grad_output_w_ptr,
                                       T* grad_output_b_ptr,
                                       T* grad_norm_w_ptr,
                                       T* grad_norm_b_ptr)
{
    rocblas_set_stream(_cublasHandle, _stream);

    if (!_stochastic_mode) { hipStreamSynchronize(_stream); }

    // ---- workspace layout ----------------------------------------------------
    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1 + small_buf_size;
    T* buf_3 = buf_2 + small_buf_size;

    T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size)
                                   : buf_3 + small_buf_size);
    T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);

    // Both slots refer to the same stream.
    hipStream_t streams[2] = {_stream, _stream};

    int bsz_seq = bsz * _seq_length;
    int bsz_heads = bsz * _heads;

    // ---- final layer norm (post-LN) and layer output dropout -------------------
    // The dropout backward consumes grad_output directly for pre-LN, or the
    // layer-norm gradient in buf_1 for post-LN.
    const T* out_dropout_grad = grad_output_ptr;
    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean()) {
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 inp_norm_ptr);
        } else {
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 norm_b_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 output_ptr);
        }
        out_dropout_grad = buf_1;
    }

    _layer_output_dropout.Backward(bsz_seq, buf_0, out_dropout_grad, _stream);

    // Without dropout, FF2's backward reads the dropout's input instead of buf_0.
    const T* layer_dropout_buf = _layer_output_dropout.HasDropout() ? buf_0 : out_dropout_grad;

    // ---- feed-forward block ----------------------------------------------------
    const T* ff2_input = ff2_inp_ptr;
    const T* gelu_input = gelu_inp_ptr;
    if (_gelu_checkpoint) {
        // Recompute the GELU output that the forward pass did not keep.
        _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
        ff2_input = buf_2;
        gelu_input = ff2_inp_ptr;
    }

    _ff2.Backward(bsz_seq,
                  layer_dropout_buf,
                  ff2_input,
                  output_w_ptr,
                  grad_output_w_ptr,
                  grad_output_b_ptr,
                  _cublasHandle,
                  _stream,
                  ff2_buf);

    _gelu.Backward(bsz_seq, ff2_buf, gelu_input, inter_b_ptr, _stream);

    _ff1.Backward(bsz_seq,
                  ff2_buf,
                  ff1_inp_ptr,
                  inter_w_ptr,
                  grad_inter_w_ptr,
                  grad_inter_b_ptr,
                  _cublasHandle,
                  _stream,
                  buf_3);

    // ---- attention layer norm --------------------------------------------------
    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean()) {
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              add_res_ptr);
        } else {
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              attn_nb_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              ff1_inp_ptr);
        }
    } else {
        launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);

        if (_attn_layer_norm.UseMean()) {
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      add_res_ptr);
        } else {
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      attn_nb_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      ff1_inp_ptr);
        }
    }

    // ---- attention output dropout and linear -----------------------------------
    _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);

    T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;

    _attn_out_linear.Backward(bsz_seq,
                              attn_output_dropout_buf,
                              attn_o_inp_ptr,
                              attn_ow_ptr,
                              grad_attn_ow_ptr,
                              grad_attn_ob_ptr,
                              _cublasHandle,
                              _stream,
                              buf_1);

    launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);

    // ---- attention context -----------------------------------------------------
    // Pick the probabilities the context GEMM saw in the forward pass:
    // the softmax output when dropout is off, the saved dropped tensor when
    // it is on, or a recomputed copy when dropout checkpointing is enabled.
    const T* attn_probs = soft_out_ptr;
    if (_attn_prob_dropout.HasDropout()) {
        attn_probs = ctx_bufB_ptr;
        if (_attn_dropout_checkpoint) {
            _attn_prob_dropout.Forward(
                bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);
            attn_probs = ctx_bufB_ptr_recomp;
        }
    }

    _attn_context.Backward(bsz_heads, buf_2, v_tf_ptr, attn_probs, _cublasHandle, buf_3, ff2_buf);

    _attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);

    _softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);

    _attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);

    launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);

    // ---- QKV linear and input gradient -----------------------------------------
    // The QKV projection consumed the normalized input for pre-LN and the raw
    // input for post-LN.
    const T* qkv_input = _pre_or_postLayerNorm ? inp_norm_ptr : input_ptr;
    _qkv_linear.Backward(bsz_seq,
                         ff2_buf,
                         qkv_input,
                         attn_qkvw_ptr,
                         grad_attn_qkvw_ptr,
                         grad_attn_qkvb_ptr,
                         _cublasHandle,
                         _stream,
                         buf_2);

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean()) {
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         input_ptr);
        } else {
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         norm_b_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         inp_norm_ptr);
        }
    } else {
        launch_fused_add2<T>(
            grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
    }
}
/// Forwards the train/eval flag to the three dropout stages of this layer.
///
/// Only the dropout members are updated here; no other member of the layer is
/// written by this function.
template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
    // Per the original note, dropout is skipped when not in training mode.
    // Attention-probability dropout first, then the two hidden-state dropouts.
    _attn_prob_dropout.SetTrainingMode(training);
    _attn_output_dropout.SetTrainingMode(training);
    _layer_output_dropout.SetTrainingMode(training);
}
/// Points the layer's dropout and layer-norm members at caller-owned buffers.
///
/// @param attn_prob_dropout_mask_ptr     mask buffer for the attention-probability dropout
/// @param attn_output_dropout_mask_ptr   mask buffer for the attention-output dropout
/// @param layer_output_dropout_mask_ptr  mask buffer for the layer-output dropout
/// @param attn_layer_norm_var            variance buffer for the attention layer norm
/// @param attn_layer_norm_mean           mean buffer for the attention layer norm
/// @param layer_norm_var                 variance buffer for the second layer norm
/// @param layer_norm_mean                mean buffer for the second layer norm
///
/// The pointers are only stored through the members' setters; nothing is
/// allocated or copied here.
template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                                     uint8_t* attn_output_dropout_mask_ptr,
                                                     uint8_t* layer_output_dropout_mask_ptr,
                                                     T* attn_layer_norm_var,
                                                     T* attn_layer_norm_mean,
                                                     T* layer_norm_var,
                                                     T* layer_norm_mean)
{
    // Dropout masks.
    _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
    _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
    _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);

    // Statistics of the attention layer norm.
    _attn_layer_norm.SetVar(attn_layer_norm_var);
    _attn_layer_norm.SetMean(attn_layer_norm_mean);

    // Statistics of the second layer norm.
    _layer_norm.SetVar(layer_norm_var);
    _layer_norm.SetMean(layer_norm_mean);
}
/// Switches the layer to a new sequence length.
///
/// Stores the new length and re-configures the members whose configuration is
/// updated below: softmax, attention-probability dropout and the two strided
/// batched GEMMs (attention scores and attention context).
template <typename T>
void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
{
    _seq_length = seq_len;

    // Per-head width, shared by both batched GEMM configurations.
    const auto head_dim = _hidden_size / _heads;

    _softmax.SetSeqLength(_seq_length);
    _attn_prob_dropout.SetDimension(_seq_length);

    _attn_scores.SetConfig(_seq_length, _seq_length, head_dim);
    _attn_context.SetConfig(head_dim, _seq_length, _seq_length);
}
// Creates a BertTransformerLayer<T> and registers it in s_transformer_layers
// under `layer_id`, replacing any entry already stored for that id.
//
// Before constructing the layer this seeds the global Context with `seed` and
// calls Context::TestGemmFP16 with `test_gemm` and the problem dimensions
// (using the file-level init_seq_length as the initial sequence length).
// All remaining arguments are forwarded to the layer constructor.
//
// Prints a one-line creation message to stdout.
// Returns 0 unconditionally.
// NOTE(review): num_heads == 0 would divide by zero in hidden_dim / num_heads;
// callers are assumed to pass a positive head count.
template <typename T>
int create_transformer_layer(unsigned layer_id,
                             unsigned batch_size,
                             unsigned hidden_dim,
                             unsigned num_heads,
                             unsigned intermediate_size,
                             float attn_dropout_ratio,
                             float hidden_dropout_ratio,
                             float layer_norm_eps,
                             int seed,
                             bool pre_or_postLayerNorm,
                             bool test_gemm,
                             bool attn_dropout_checkpoint,
                             bool normalize_invertible,
                             bool gelu_checkpoint,
                             bool stochastic_mode)
{
    Context::Instance().SetSeed(seed);
    Context::Instance().TestGemmFP16(
        test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

    auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
                                                           batch_size,
                                                           hidden_dim,
                                                           num_heads,
                                                           intermediate_size,
                                                           init_seq_length,
                                                           attn_dropout_ratio,
                                                           hidden_dropout_ratio,
                                                           layer_norm_eps,
                                                           pre_or_postLayerNorm,
                                                           Context::Instance().GetGemmAlgos(),
                                                           attn_dropout_checkpoint,
                                                           normalize_invertible,
                                                           gelu_checkpoint,
                                                           stochastic_mode);

    s_transformer_layers[layer_id] = layer;

    // A string literal is enough here; no std::string temporary is needed.
    const char* const dtype = std::is_same<T, __half>::value ? "half" : "float";

    // Message previously read "date type"; corrected to "data type".
    std::cout << "layer #" << layer_id << " is created with data type [" << dtype << "]."
              << std::endl;

    return 0;
}
// Runs the forward pass of the transformer layer registered under `layer_id`.
//
// Every tensor argument is validated with CHECK_INPUT and then handed to
// BertTransformerLayer<T>::Forward as a raw T pointer. The function allocates
// the output, a workspace (registered with the global Context), and all
// intermediate activations / dropout masks / layer-norm statistics, and
// returns them so the caller can pass them back to ds_transformer_backward.
//
// If input.size(1) differs from the layer's current sequence length, the
// layer is re-configured via SetSeqLength before anything is allocated.
//
// Buffer aliasing controlled by the boolean flags:
//   - inp_norm aliases `output` unless (prelayernorm || !normalize_invertible)
//   - add_res  aliases inp_norm when normalize_invertible
//   - gelu_inp aliases ff2_inp when gelu_checkpoint
//   - ctx_bufB aliases soft_out when attn_dropout_checkpoint
//
// Returns, in this order: output, inp_norm, qkv_tf, soft_out, ctx_bufB,
// attn_o_inp, add_res, ff1_inp, gelu_inp, ff2_inp, attn_prob_dropout_mask,
// attn_output_dropout_mask, layer_output_dropout_mask, attn_layer_norm_var,
// attn_layer_norm_mean, layer_norm_var, layer_norm_mean.
//
// Raises (via TORCH_CHECK) if no layer has been registered under `layer_id`;
// previously that case dereferenced a null pointer.
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
                                                  const torch::Tensor& input,
                                                  const torch::Tensor& input_mask,
                                                  const torch::Tensor& attn_qkvw,
                                                  const torch::Tensor& attn_qkvb,
                                                  const torch::Tensor& attn_ow,
                                                  const torch::Tensor& attn_ob,
                                                  const torch::Tensor& attn_nw,
                                                  const torch::Tensor& attn_nb,
                                                  const torch::Tensor& inter_w,
                                                  const torch::Tensor& inter_b,
                                                  const torch::Tensor& output_w,
                                                  const torch::Tensor& output_b,
                                                  const torch::Tensor& norm_w,
                                                  const torch::Tensor& norm_b,
                                                  bool training_mode,
                                                  bool prelayernorm,
                                                  bool attn_dropout_checkpoint,
                                                  bool normalize_invertible,
                                                  bool gelu_checkpoint)
{
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    // Typed views of a tensor's storage (replaces the repeated C-style casts).
    const auto const_ptr = [](const torch::Tensor& t) {
        return static_cast<const T*>(t.data_ptr());
    };
    const auto mut_ptr = [](const torch::Tensor& t) { return static_cast<T*>(t.data_ptr()); };
    const auto mask_ptr = [](const torch::Tensor& t) {
        return static_cast<uint8_t*>(t.data_ptr());
    };

    unsigned bsz = input.size(0);

    const T* input_ptr = const_ptr(input);
    const T* input_mask_ptr = const_ptr(input_mask);
    const T* attn_qkvw_ptr = const_ptr(attn_qkvw);
    const T* attn_qkvb_ptr = const_ptr(attn_qkvb);
    const T* attn_ow_ptr = const_ptr(attn_ow);
    const T* attn_ob_ptr = const_ptr(attn_ob);
    const T* attn_nw_ptr = const_ptr(attn_nw);
    const T* attn_nb_ptr = const_ptr(attn_nb);
    const T* inter_w_ptr = const_ptr(inter_w);
    const T* inter_b_ptr = const_ptr(inter_b);
    const T* output_w_ptr = const_ptr(output_w);
    const T* output_b_ptr = const_ptr(output_b);
    const T* norm_w_ptr = const_ptr(norm_w);
    const T* norm_b_ptr = const_ptr(norm_b);

    auto output = torch::empty_like(input);
    T* out_ptr = mut_ptr(output);

    auto options = torch::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    // NOTE(review): the dtype is kInt8 although the buffers are consumed as
    // uint8_t* below; element size is the same. Left unchanged because the
    // tensors are returned to the caller.
    auto uint8_options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .layout(torch::kStrided)
                             .device(torch::kCUDA)
                             .requires_grad(false);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
    // Fail with a clear error instead of dereferencing a null layer below.
    TORCH_CHECK(layer != nullptr,
                "ds_transformer_forward: no transformer layer registered for layer_id ",
                layer_id);

    unsigned seq_len = layer->GetSeqLength();
    if (input.size(1) != seq_len) {
        seq_len = input.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace(mut_ptr(workspace));

    auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
    auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
    auto attn_o_inp = torch::empty_like(input);
    // Q, K and V are stored back to back in one allocation.
    auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);

    auto attn_prob_dropout_mask =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
    auto attn_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
    auto layer_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);

    auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);

    T* inp_norm_ptr = mut_ptr(inp_norm);
    T* add_res_ptr = mut_ptr(add_res);
    // K and V start one and two (bsz * seq_len * output_w.size(0)) strides after Q.
    T* q_tf_ptr = mut_ptr(qkv_tf);
    T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0));
    T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0));
    T* attn_o_inp_ptr = mut_ptr(attn_o_inp);

    torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
    torch::Tensor gelu_inp =
        (gelu_checkpoint ? ff2_inp
                         : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
    auto ff1_inp = torch::empty_like(input);
    T* ff2_inp_ptr = mut_ptr(ff2_inp);
    T* gelu_inp_ptr = mut_ptr(gelu_inp);
    T* ff1_inp_ptr = mut_ptr(ff1_inp);

    torch::Tensor soft_out =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
    torch::Tensor ctx_bufB =
        (attn_dropout_checkpoint
             ? soft_out
             : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
    T* soft_out_ptr = mut_ptr(soft_out);
    T* ctx_bufB_ptr = mut_ptr(ctx_bufB);

    layer->SetTrainingMode(training_mode);
    layer->SetIntermediateBuffers(mask_ptr(attn_prob_dropout_mask),
                                  mask_ptr(attn_output_dropout_mask),
                                  mask_ptr(layer_output_dropout_mask),
                                  mut_ptr(attn_layer_norm_var),
                                  mut_ptr(attn_layer_norm_mean),
                                  mut_ptr(layer_norm_var),
                                  mut_ptr(layer_norm_mean));

    layer->Forward(bsz,
                   input_ptr,
                   input_mask_ptr,
                   attn_qkvw_ptr,
                   attn_qkvb_ptr,
                   attn_ow_ptr,
                   attn_ob_ptr,
                   attn_nw_ptr,
                   attn_nb_ptr,
                   inter_w_ptr,
                   inter_b_ptr,
                   output_w_ptr,
                   output_b_ptr,
                   norm_w_ptr,
                   norm_b_ptr,
                   out_ptr,
                   inp_norm_ptr,
                   q_tf_ptr,
                   k_tf_ptr,
                   v_tf_ptr,
                   soft_out_ptr,
                   ctx_bufB_ptr,
                   attn_o_inp_ptr,
                   add_res_ptr,
                   ff1_inp_ptr,
                   gelu_inp_ptr,
                   ff2_inp_ptr);

    return {output,
            inp_norm,
            qkv_tf,
            soft_out,
            ctx_bufB,
            attn_o_inp,
            add_res,
            ff1_inp,
            gelu_inp,
            ff2_inp,
            attn_prob_dropout_mask,
            attn_output_dropout_mask,
            layer_output_dropout_mask,
            attn_layer_norm_var,
            attn_layer_norm_mean,
            layer_norm_var,
            layer_norm_mean};
}
// Runs the backward pass of the transformer layer registered under `layer_id`.
//
// `grad_output` is made contiguous first; it and the listed activation /
// parameter tensors are validated with CHECK_INPUT. The saved activations,
// dropout masks and layer-norm statistics are the tensors produced by the
// matching forward call. A fresh workspace is allocated and registered with
// the global Context, and one gradient tensor is allocated (empty_like) for
// the input and for each parameter.
//
// If g_output.size(1) differs from the layer's current sequence length, the
// layer is re-configured via SetSeqLength first.
//
// Returns, in this order: grad_input, grad_attn_qkvw, grad_attn_qkvb,
// grad_attn_ow, grad_attn_ob, grad_attn_nw, grad_attn_nb, grad_inter_w,
// grad_inter_b, grad_output_w, grad_output_b, grad_norm_w, grad_norm_b.
//
// Raises (via TORCH_CHECK) if no layer has been registered under `layer_id`;
// previously that case dereferenced a null pointer.
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
                                                   const torch::Tensor& grad_output,
                                                   const torch::Tensor& output,
                                                   const torch::Tensor& inp_norm,
                                                   const torch::Tensor& qkv_tf,
                                                   const torch::Tensor& soft_out,
                                                   const torch::Tensor& ctx_bufB,
                                                   const torch::Tensor& attn_o_inp,
                                                   const torch::Tensor& add_res,
                                                   const torch::Tensor& ff1_inp,
                                                   const torch::Tensor& gelu_inp,
                                                   const torch::Tensor& ff2_inp,
                                                   const torch::Tensor& attn_prob_dropout_mask,
                                                   const torch::Tensor& attn_output_dropout_mask,
                                                   const torch::Tensor& layer_output_dropout_mask,
                                                   const torch::Tensor& attn_layer_norm_var,
                                                   const torch::Tensor& attn_layer_norm_mean,
                                                   const torch::Tensor& layer_norm_var,
                                                   const torch::Tensor& layer_norm_mean,
                                                   const torch::Tensor& input,
                                                   const torch::Tensor& input_mask,
                                                   const torch::Tensor& attn_qkvw,
                                                   const torch::Tensor& attn_qkvb,
                                                   const torch::Tensor& attn_ow,
                                                   const torch::Tensor& attn_ob,
                                                   const torch::Tensor& attn_nw,
                                                   const torch::Tensor& attn_nb,
                                                   const torch::Tensor& inter_w,
                                                   const torch::Tensor& inter_b,
                                                   const torch::Tensor& output_w,
                                                   const torch::Tensor& output_b,
                                                   const torch::Tensor& norm_w,
                                                   const torch::Tensor& norm_b)
{
    auto g_output = grad_output.contiguous();

    CHECK_INPUT(g_output);
    CHECK_INPUT(output);
    CHECK_INPUT(inp_norm);
    CHECK_INPUT(qkv_tf);
    CHECK_INPUT(add_res);
    CHECK_INPUT(soft_out);
    CHECK_INPUT(ctx_bufB);
    CHECK_INPUT(attn_o_inp);
    CHECK_INPUT(ff1_inp);
    CHECK_INPUT(gelu_inp);
    CHECK_INPUT(ff2_inp);
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    // Typed views of a tensor's storage (replaces the repeated C-style casts).
    const auto const_ptr = [](const torch::Tensor& t) {
        return static_cast<const T*>(t.data_ptr());
    };
    const auto mut_ptr = [](const torch::Tensor& t) { return static_cast<T*>(t.data_ptr()); };
    const auto mask_ptr = [](const torch::Tensor& t) {
        return static_cast<uint8_t*>(t.data_ptr());
    };

    unsigned bsz = g_output.size(0);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
    // Fail with a clear error instead of dereferencing a null layer below.
    TORCH_CHECK(layer != nullptr,
                "ds_transformer_backward: no transformer layer registered for layer_id ",
                layer_id);

    unsigned seq_len = layer->GetSeqLength();
    if (g_output.size(1) != seq_len) {
        seq_len = g_output.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto options = torch::TensorOptions()
                       .dtype(g_output.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace(mut_ptr(workspace));

    auto grad_input = torch::empty_like(input);
    auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
    auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
    auto grad_attn_ow = torch::empty_like(attn_ow);
    auto grad_attn_ob = torch::empty_like(attn_ob);
    auto grad_attn_nw = torch::empty_like(attn_nw);
    auto grad_attn_nb = torch::empty_like(attn_nb);
    auto grad_inter_w = torch::empty_like(inter_w);
    auto grad_inter_b = torch::empty_like(inter_b);
    auto grad_output_w = torch::empty_like(output_w);
    auto grad_output_b = torch::empty_like(output_b);
    auto grad_norm_w = torch::empty_like(norm_w);
    auto grad_norm_b = torch::empty_like(norm_b);

    // inputs.
    const T* grad_output_ptr = const_ptr(g_output);
    const T* input_ptr = const_ptr(input);
    const T* output_ptr = const_ptr(output);
    const T* inp_norm_ptr = const_ptr(inp_norm);
    const T* q_tf_ptr = const_ptr(qkv_tf);
    const T* add_res_ptr = const_ptr(add_res);
    // K and V follow Q inside the single qkv_tf allocation.
    const T* k_tf_ptr = q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));
    const T* v_tf_ptr = k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));
    const T* ff1_inp_ptr = const_ptr(ff1_inp);
    const T* gelu_inp_ptr = const_ptr(gelu_inp);
    const T* ff2_inp_ptr = const_ptr(ff2_inp);
    const T* ctx_bufB_ptr = const_ptr(ctx_bufB);
    const T* soft_out_ptr = const_ptr(soft_out);
    const T* attn_o_inp_ptr = const_ptr(attn_o_inp);
    const T* input_mask_ptr = const_ptr(input_mask);
    const T* attn_qkvw_ptr = const_ptr(attn_qkvw);
    const T* attn_ow_ptr = const_ptr(attn_ow);
    const T* attn_nw_ptr = const_ptr(attn_nw);
    const T* attn_nb_ptr = const_ptr(attn_nb);
    const T* inter_w_ptr = const_ptr(inter_w);
    const T* inter_b_ptr = const_ptr(inter_b);
    const T* output_w_ptr = const_ptr(output_w);
    const T* norm_w_ptr = const_ptr(norm_w);
    const T* norm_b_ptr = const_ptr(norm_b);

    // outputs.
    T* grad_input_ptr = mut_ptr(grad_input);
    T* grad_attn_qkvw_ptr = mut_ptr(grad_attn_qkvw);
    T* grad_attn_qkvb_ptr = mut_ptr(grad_attn_qkvb);
    T* grad_attn_ow_ptr = mut_ptr(grad_attn_ow);
    T* grad_attn_ob_ptr = mut_ptr(grad_attn_ob);
    T* grad_attn_nw_ptr = mut_ptr(grad_attn_nw);
    T* grad_attn_nb_ptr = mut_ptr(grad_attn_nb);
    T* grad_inter_w_ptr = mut_ptr(grad_inter_w);
    T* grad_inter_b_ptr = mut_ptr(grad_inter_b);
    T* grad_output_w_ptr = mut_ptr(grad_output_w);
    T* grad_output_b_ptr = mut_ptr(grad_output_b);
    T* grad_norm_w_ptr = mut_ptr(grad_norm_w);
    T* grad_norm_b_ptr = mut_ptr(grad_norm_b);

    layer->SetIntermediateBuffers(mask_ptr(attn_prob_dropout_mask),
                                  mask_ptr(attn_output_dropout_mask),
                                  mask_ptr(layer_output_dropout_mask),
                                  mut_ptr(attn_layer_norm_var),
                                  mut_ptr(attn_layer_norm_mean),
                                  mut_ptr(layer_norm_var),
                                  mut_ptr(layer_norm_mean));

    layer->Backward(bsz,
                    grad_output_ptr,
                    input_ptr,
                    output_ptr,
                    inp_norm_ptr,
                    q_tf_ptr,
                    k_tf_ptr,
                    v_tf_ptr,
                    soft_out_ptr,
                    ctx_bufB_ptr,
                    attn_o_inp_ptr,
                    add_res_ptr,
                    ff1_inp_ptr,
                    gelu_inp_ptr,
                    ff2_inp_ptr,
                    input_mask_ptr,
                    attn_qkvw_ptr,
                    attn_ow_ptr,
                    attn_nw_ptr,
                    attn_nb_ptr,
                    inter_w_ptr,
                    inter_b_ptr,
                    output_w_ptr,
                    norm_w_ptr,
                    norm_b_ptr,
                    grad_input_ptr,
                    grad_attn_qkvw_ptr,
                    grad_attn_qkvb_ptr,
                    grad_attn_ow_ptr,
                    grad_attn_ob_ptr,
                    grad_attn_nw_ptr,
                    grad_attn_nb_ptr,
                    grad_inter_w_ptr,
                    grad_inter_b_ptr,
                    grad_output_w_ptr,
                    grad_output_b_ptr,
                    grad_norm_w_ptr,
                    grad_norm_b_ptr);

    return {grad_input,
            grad_attn_qkvw,
            grad_attn_qkvb,
            grad_attn_ow,
            grad_attn_ob,
            grad_attn_nw,
            grad_attn_nb,
            grad_inter_w,
            grad_inter_b,
            grad_output_w,
            grad_output_b,
            grad_norm_w,
            grad_norm_b};
}
// Python bindings for the transformer extension. Each entry point is exposed
// twice, instantiated for float ("_fp32") and for __half ("_fp16").
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward_fp32",
          &ds_transformer_forward<float>,
          "DeepSpeed Transformer forward with fp32 (CUDA)");
    m.def("forward_fp16",
          &ds_transformer_forward<__half>,
          "DeepSpeed Transformer forward with fp16 (CUDA)");
    m.def("backward_fp32",
          &ds_transformer_backward<float>,
          "DeepSpeed Transformer backward with fp32 (CUDA)");
    m.def("backward_fp16",
          &ds_transformer_backward<__half>,
          "DeepSpeed Transformer backward with fp16 (CUDA)");
    // Docstrings below previously contained a duplicated word ("Transformer Transformer").
    m.def("create_transformer_layer_fp32",
          &create_transformer_layer<float>,
          "Create DeepSpeed Transformer Layer with fp32 (CUDA)");
    m.def("create_transformer_layer_fp16",
          &create_transformer_layer<__half>,
          "Create DeepSpeed Transformer Layer with fp16 (CUDA)");
}
deepspeed/ops/csrc/transformer/gelu_kernels.cu
deleted
100644 → 0
View file @
1b2721ad
#include "custom_cuda_layers.h"
// Tanh-based GELU approximation:
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
inline __device__ float gelu(const float x)
{
    const float sqrt_two_over_pi = 0.79788456080286535587989211986876f;
    const float cubic_coeff = 0.044715;

    // Argument of the hyperbolic tangent.
    const float tanh_arg = sqrt_two_over_pi * (x + cubic_coeff * x * x * x);

    return x * 0.5f * (1.0f + tanhf(tanh_arg));
}
// Derivative with respect to x of the tanh-based GELU approximation used by
// gelu(): with u = sqrt(2/pi) * (x + c*x^3) and c = 0.044715,
//   d_gelu(x) = 0.5*(1 + tanh u) + 0.5*x*sqrt(2/pi)*(1 - tanh^2 u)*(1 + 3*c*x^2)
// The last factor is expanded into two summands below.
inline __device__ float d_gelu(const float x)
{
    const float sqrt_two_over_pi = 0.79788456080286535587989211986876f;
    const float cubic_coeff = 0.044715;

    const float cx2 = x * x * cubic_coeff;  // c * x^2
    const float th = tanhf(sqrt_two_over_pi * (x + x * cx2));

    const float term_tanh = 0.5f * (1.0f + th);
    const float term_sech = x * 0.5f * sqrt_two_over_pi * (1 - th * th);
    const float term_cubic = term_sech * 3 * cx2;

    return (term_tanh + term_sech + term_cubic);
}
/*
Fused bias add with GELU
Loads a vector of 4 elements each iteration, for stride
iterations. It was written with the intention to launch 256 thread
threadblocks, so to launch for bert-large, we would set ITERATIONS
to 4. This is currently done automatically as a heuristic, setting
the number of iterations as blocks of 1024.
For FP16, the values are loaded from memory as __half, but converted
to FP32 for the arithmetic itself, to prevent numerical overflow in
the intermediate hyperbolic tangent, since there's no intrinsic
that computes it directly.
*/
// Element-wise GELU, fp32. Each block handles the row given by blockIdx.x;
// each thread processes one float4 (four floats) per loop iteration.
//
// input / vals : source and destination, accessed as float4 arrays
// row_stride   : row length measured in float4 elements
// iterations   : number of blockDim.x-wide passes each thread makes over a row
__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations)
{
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;

    const float4* src = reinterpret_cast<const float4*>(input);
    float4* dst = reinterpret_cast<float4*>(vals);

    for (int i = 0; i < iterations; i++) {
        const int col = i * step + tid;
        // Threads past the end of the row have nothing to do in this pass.
        if (col >= row_stride) { continue; }

        const int offset = row * row_stride + col;
        float4 v = src[offset];

        v.x = gelu(v.x);
        v.y = gelu(v.y);
        v.z = gelu(v.z);
        v.w = gelu(v.w);

        dst[offset] = v;
    }
}
// Element-wise GELU, fp16. Compiles to an empty kernel unless
// HALF_PRECISION_AVAILABLE is defined.
//
// Each thread loads one float2-sized chunk (four __half values) per loop
// iteration, converts it to fp32, applies gelu(), and converts back with
// round-to-nearest.
//
// row_stride : row length measured in float2-sized chunks (4 halves each)
// iterations : number of blockDim.x-wide passes each thread makes over a row
__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;

    const float2* src = reinterpret_cast<const float2*>(input);
    float2* dst = reinterpret_cast<float2*>(vals);

    for (int i = 0; i < iterations; i++) {
        const int col = i * step + tid;
        if (col >= row_stride) { continue; }

        const int offset = row * row_stride + col;

        // One 8-byte load, viewed as two __half2 pairs.
        float2 packed = src[offset];
        __half2* pairs = reinterpret_cast<__half2*>(&packed);

        float2 lo = __half22float2(pairs[0]);
        float2 hi = __half22float2(pairs[1]);

        lo.x = gelu(lo.x);
        lo.y = gelu(lo.y);
        hi.x = gelu(hi.x);
        hi.y = gelu(hi.y);

        pairs[0] = __float22half2_rn(lo);
        pairs[1] = __float22half2_rn(hi);

        dst[offset] = packed;
    }
#endif
}
// Fused bias-add + GELU, fp32: vals = gelu(input + bias).
// Each block handles the row given by blockIdx.x; the bias is indexed by the
// column only, so the same bias vector is applied to every row.
//
// input / bias / vals : accessed as float4 arrays
// row_stride          : row length measured in float4 elements
// iterations          : number of blockDim.x-wide passes over a row
__global__ void fused_bias_gelu(const float* input,
                                const float* bias,
                                float* vals,
                                int row_stride,
                                int iterations)
{
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;

    const float4* src = reinterpret_cast<const float4*>(input);
    float4* dst = reinterpret_cast<float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);

    for (int i = 0; i < iterations; i++) {
        const int col = i * step + tid;
        if (col >= row_stride) { continue; }

        const int offset = row * row_stride + col;
        float4 v = src[offset];
        const float4 b = bias_vec[col];

        v.x += b.x;
        v.y += b.y;
        v.z += b.z;
        v.w += b.w;

        v.x = gelu(v.x);
        v.y = gelu(v.y);
        v.z = gelu(v.z);
        v.w = gelu(v.w);

        dst[offset] = v;
    }
}
// Fused bias-add + GELU, fp16: vals = gelu(input + bias), with the arithmetic
// carried out in fp32. Compiles to an empty kernel unless
// HALF_PRECISION_AVAILABLE is defined.
//
// Each thread handles one float2-sized chunk (four __half values) per loop
// iteration. The bias is indexed by the column only.
//
// row_stride : row length measured in float2-sized chunks (4 halves each)
// iterations : number of blockDim.x-wide passes over a row
__global__ void fused_bias_gelu(const __half* input,
                                const __half* bias,
                                __half* vals,
                                int row_stride,
                                int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;

    const float2* src = reinterpret_cast<const float2*>(input);
    float2* dst = reinterpret_cast<float2*>(vals);
    const float2* bias_src = reinterpret_cast<const float2*>(bias);

    for (int i = 0; i < iterations; i++) {
        const int col = i * step + tid;
        if (col >= row_stride) { continue; }

        const int offset = row * row_stride + col;

        float2 packed = src[offset];
        float2 packed_bias = bias_src[col];

        __half2* pairs = reinterpret_cast<__half2*>(&packed);
        __half2* bias_pairs = reinterpret_cast<__half2*>(&packed_bias);

        float2 lo = __half22float2(pairs[0]);
        float2 hi = __half22float2(pairs[1]);

        const float2 lo_b = __half22float2(bias_pairs[0]);
        const float2 hi_b = __half22float2(bias_pairs[1]);

        lo.x += lo_b.x;
        lo.y += lo_b.y;
        hi.x += hi_b.x;
        hi.y += hi_b.y;

        lo.x = gelu(lo.x);
        lo.y = gelu(lo.y);
        hi.x = gelu(hi.x);
        hi.y = gelu(hi.y);

        pairs[0] = __float22half2_rn(lo);
        pairs[1] = __float22half2_rn(hi);

        dst[offset] = packed;
    }
#endif
}
// GELU backward, fp32, in place on d_output:
//   d_output *= d_gelu(gelu_input + bias)
// Each block handles the row given by blockIdx.x; the bias is indexed by the
// column only.
//
// d_output / gelu_input / bias : accessed as float4 arrays
// row_stride                   : row length measured in float4 elements
// iterations                   : number of blockDim.x-wide passes over a row
__global__ void d_gelu_func(float* d_output,
                            const float* gelu_input,
                            const float* bias,
                            int row_stride,
                            int iterations)
{
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;

    float4* grad = reinterpret_cast<float4*>(d_output);
    const float4* pre_act = reinterpret_cast<const float4*>(gelu_input);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);

    for (int i = 0; i < iterations; i++) {
        const int col = i * step + tid;
        if (col >= row_stride) { continue; }

        const int offset = row * row_stride + col;

        float4 g = grad[offset];
        float4 x = pre_act[offset];
        const float4 b = bias_vec[col];

        // Rebuild the value that was fed to GELU in the forward pass.
        x.x += b.x;
        x.y += b.y;
        x.z += b.z;
        x.w += b.w;

        g.x *= d_gelu(x.x);
        g.y *= d_gelu(x.y);
        g.z *= d_gelu(x.z);
        g.w *= d_gelu(x.w);

        grad[offset] = g;
    }
}
// GELU backward, fp16, in place on d_output:
//   d_output *= d_gelu(gelu_input + bias)
// with the arithmetic carried out in fp32. Compiles to an empty kernel unless
// HALF_PRECISION_AVAILABLE is defined.
//
// Each thread handles one float2-sized chunk (four __half values) per loop
// iteration. The bias is indexed by the column only.
//
// row_stride : row length measured in float2-sized chunks (4 halves each)
// iterations : number of blockDim.x-wide passes over a row
__global__ void d_gelu_func(__half* d_output,
                            const __half* gelu_input,
                            const __half* bias,
                            int row_stride,
                            int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;

    float2* grad = reinterpret_cast<float2*>(d_output);
    const float2* pre_act = reinterpret_cast<const float2*>(gelu_input);
    const float2* bias_src = reinterpret_cast<const float2*>(bias);

#pragma unroll
    for (int i = 0; i < iterations; i++) {
        const int col = i * step + tid;
        if (col >= row_stride) { continue; }

        const int offset = row * row_stride + col;

        float2 packed_grad = grad[offset];
        float2 packed_x = pre_act[offset];
        float2 packed_bias = bias_src[col];

        __half2* grad_pairs = reinterpret_cast<__half2*>(&packed_grad);
        __half2* x_pairs = reinterpret_cast<__half2*>(&packed_x);
        __half2* bias_pairs = reinterpret_cast<__half2*>(&packed_bias);

        float2 g_lo = __half22float2(grad_pairs[0]);
        float2 g_hi = __half22float2(grad_pairs[1]);

        float2 x_lo = __half22float2(x_pairs[0]);
        float2 x_hi = __half22float2(x_pairs[1]);

        const float2 b_lo = __half22float2(bias_pairs[0]);
        const float2 b_hi = __half22float2(bias_pairs[1]);

        // Rebuild the value that was fed to GELU in the forward pass.
        x_lo.x += b_lo.x;
        x_lo.y += b_lo.y;
        x_hi.x += b_hi.x;
        x_hi.y += b_hi.y;

        g_lo.x *= d_gelu(x_lo.x);
        g_lo.y *= d_gelu(x_lo.y);
        g_hi.x *= d_gelu(x_hi.x);
        g_hi.y *= d_gelu(x_hi.y);

        // Pack the four results back into one 8-byte store.
        float2 packed_out;
        __half2* out_pairs = reinterpret_cast<__half2*>(&packed_out);
        out_pairs[0] = __float22half2_rn(g_lo);
        out_pairs[1] = __float22half2_rn(g_hi);

        grad[offset] = packed_out;
    }
#endif
}
// Launches fused_bias_gelu on `stream`: output = gelu(input + bias).
//
// input, output     : batch_size rows of intermediate_size elements
// bias              : intermediate_size elements, applied to every row
// intermediate_size : row length in elements. The kernel works on groups of
//                     four elements (row stride passed as intermediate_size / 4),
//                     so any remainder beyond a multiple of 4 is not processed.
// batch_size        : number of rows; one thread block is launched per row
//
// Launch shape: iterations = ceil(intermediate_size / 1024) passes per thread,
// with enough threads that threads * iterations * 4 covers a row.
//
// Returns without launching when intermediate_size or batch_size is not
// positive.
template <typename T>
void launch_bias_gelu(const T* input,
                      const T* bias,
                      T* output,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream)
{
    // Empty problem: nothing to compute. This also avoids a division by zero
    // below (iterations would be 0) and a zero-sized grid launch.
    if (intermediate_size <= 0 || batch_size <= 0) { return; }

    int iterations = (intermediate_size + 1023) / 1024;
    int threads = (intermediate_size - 1) / (iterations * 4) + 1;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(
        input, bias, output, intermediate_size / 4, iterations);
}
// Launches gelu_kernel on `stream`: output = gelu(input).
//
// input, output     : batch_size rows of intermediate_size elements
// intermediate_size : row length in elements. The kernel works on groups of
//                     four elements (row stride passed as intermediate_size / 4),
//                     so any remainder beyond a multiple of 4 is not processed.
// batch_size        : number of rows; one thread block is launched per row
//
// Launch shape: iterations = ceil(intermediate_size / 1024) passes per thread,
// with enough threads that threads * iterations * 4 covers a row.
//
// Returns without launching when intermediate_size or batch_size is not
// positive.
template <typename T>
void launch_gelu(const T* input,
                 T* output,
                 int intermediate_size,
                 int batch_size,
                 cudaStream_t stream)
{
    // Empty problem: nothing to compute. This also avoids a division by zero
    // below (iterations would be 0) and a zero-sized grid launch.
    if (intermediate_size <= 0 || batch_size <= 0) { return; }

    int iterations = (intermediate_size + 1023) / 1024;
    int threads = (intermediate_size - 1) / (iterations * 4) + 1;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(
        input, output, intermediate_size / 4, iterations);
}
// Explicit instantiations of the forward launchers for float and __half, so
// the template definitions above are compiled into this translation unit.
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(const __half*,
                                       const __half*,
                                       __half*,
                                       int,
                                       int,
                                       cudaStream_t);

template void launch_gelu<float>(const float*, float*, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t);
// Launches d_gelu_func on `stream`, updating d_output in place:
//   d_output *= d_gelu(input + bias)
//
// d_output, input   : batch_size rows of intermediate_size elements
// bias              : intermediate_size elements, applied to every row
// intermediate_size : row length in elements. The kernel works on groups of
//                     four elements (row stride passed as intermediate_size / 4),
//                     so any remainder beyond a multiple of 4 is not processed.
// batch_size        : number of rows; one thread block is launched per row
//
// Launch shape: iterations = ceil(intermediate_size / 1024) passes per thread,
// with enough threads that threads * iterations * 4 covers a row.
//
// Returns without launching when intermediate_size or batch_size is not
// positive.
template <typename T>
void launch_d_gelu(T* d_output,
                   const T* input,
                   const T* bias,
                   int intermediate_size,
                   int batch_size,
                   cudaStream_t stream)
{
    // Empty problem: nothing to compute. This also avoids a division by zero
    // below (iterations would be 0) and a zero-sized grid launch.
    if (intermediate_size <= 0 || batch_size <= 0) { return; }

    int iterations = (intermediate_size + 1023) / 1024;
    int threads = (intermediate_size - 1) / (iterations * 4) + 1;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    d_gelu_func<<<grid_dims, block_dims, 0, stream>>>(
        d_output, input, bias, intermediate_size / 4, iterations);
}
// Explicit instantiations of the backward launcher for float and __half.
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, cudaStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t);
deepspeed/ops/csrc/transformer/gelu_kernels.hip
deleted
100644 → 0
View file @
1b2721ad
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
// Tanh-based GELU approximation:
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
inline __device__ float gelu(const float x)
{
    const float sqrt_two_over_pi = 0.79788456080286535587989211986876f;
    const float cubic_coeff = 0.044715;

    // Argument of the hyperbolic tangent.
    const float tanh_arg = sqrt_two_over_pi * (x + cubic_coeff * x * x * x);

    return x * 0.5f * (1.0f + tanhf(tanh_arg));
}
// Derivative with respect to x of the tanh-based GELU approximation used by
// gelu(): with u = sqrt(2/pi) * (x + c*x^3) and c = 0.044715,
//   d_gelu(x) = 0.5*(1 + tanh u) + 0.5*x*sqrt(2/pi)*(1 - tanh^2 u)*(1 + 3*c*x^2)
// The last factor is expanded into two summands below.
inline __device__ float d_gelu(const float x)
{
    const float sqrt_two_over_pi = 0.79788456080286535587989211986876f;
    const float cubic_coeff = 0.044715;

    const float cx2 = x * x * cubic_coeff;  // c * x^2
    const float th = tanhf(sqrt_two_over_pi * (x + x * cx2));

    const float term_tanh = 0.5f * (1.0f + th);
    const float term_sech = x * 0.5f * sqrt_two_over_pi * (1 - th * th);
    const float term_cubic = term_sech * 3 * cx2;

    return (term_tanh + term_sech + term_cubic);
}
/*
Fused bias add with GELU
Loads a vector of 4 elements each iteration, for stride
iterations. It was written with the intention to launch 256 thread
threadblocks, so to launch for bert-large, we would set ITERATIONS
to 4. This is currently done automatically as a heuristic, setting
the number of iterations as blocks of 1024.
For FP16, the values are loaded from memory as __half, but converted
to FP32 for the arithmetic itself, to prevent numerical overflow in
the intermediate hyperbolic tangent, since there's no intrinsic
that computes it directly.
*/
// Element-wise GELU, fp32. Each block handles the row given by blockIdx.x;
// each thread processes one float4 (four floats) per loop iteration.
//
// input / vals : source and destination, accessed as float4 arrays
// row_stride   : row length measured in float4 elements
// iterations   : number of blockDim.x-wide passes each thread makes over a row
__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations)
{
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;

    const float4* src = reinterpret_cast<const float4*>(input);
    float4* dst = reinterpret_cast<float4*>(vals);

    for (int i = 0; i < iterations; i++) {
        const int col = i * step + tid;
        // Threads past the end of the row have nothing to do in this pass.
        if (col >= row_stride) { continue; }

        const int offset = row * row_stride + col;
        float4 v = src[offset];

        v.x = gelu(v.x);
        v.y = gelu(v.y);
        v.z = gelu(v.z);
        v.w = gelu(v.w);

        dst[offset] = v;
    }
}
// Element-wise GELU, fp16. Compiles to an empty kernel unless
// HALF_PRECISION_AVAILABLE is defined.
//
// Each thread loads one float2-sized chunk (four __half values) per loop
// iteration, converts it to fp32, applies gelu(), and converts back with
// round-to-nearest.
//
// row_stride : row length measured in float2-sized chunks (4 halves each)
// iterations : number of blockDim.x-wide passes each thread makes over a row
__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;

    const float2* src = reinterpret_cast<const float2*>(input);
    float2* dst = reinterpret_cast<float2*>(vals);

    for (int i = 0; i < iterations; i++) {
        const int col = i * step + tid;
        if (col >= row_stride) { continue; }

        const int offset = row * row_stride + col;

        // One 8-byte load, viewed as two __half2 pairs.
        float2 packed = src[offset];
        __half2* pairs = reinterpret_cast<__half2*>(&packed);

        float2 lo = __half22float2(pairs[0]);
        float2 hi = __half22float2(pairs[1]);

        lo.x = gelu(lo.x);
        lo.y = gelu(lo.y);
        hi.x = gelu(hi.x);
        hi.y = gelu(hi.y);

        pairs[0] = __float22half2_rn(lo);
        pairs[1] = __float22half2_rn(hi);

        dst[offset] = packed;
    }
#endif
}
// Fused bias-add + GELU, fp32: vals = gelu(input + bias).
// Each block handles the row given by blockIdx.x; the bias is indexed by the
// column only, so the same bias vector is applied to every row.
//
// input / bias / vals : accessed as float4 arrays
// row_stride          : row length measured in float4 elements
// iterations          : number of blockDim.x-wide passes over a row
__global__ void fused_bias_gelu(const float* input,
                                const float* bias,
                                float* vals,
                                int row_stride,
                                int iterations)
{
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;

    const float4* src = reinterpret_cast<const float4*>(input);
    float4* dst = reinterpret_cast<float4*>(vals);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);

    for (int i = 0; i < iterations; i++) {
        const int col = i * step + tid;
        if (col >= row_stride) { continue; }

        const int offset = row * row_stride + col;
        float4 v = src[offset];
        const float4 b = bias_vec[col];

        v.x += b.x;
        v.y += b.y;
        v.z += b.z;
        v.w += b.w;

        v.x = gelu(v.x);
        v.y = gelu(v.y);
        v.z = gelu(v.z);
        v.w = gelu(v.w);

        dst[offset] = v;
    }
}
// Fused bias-add + GELU, fp16: vals = gelu(input + bias), with the arithmetic
// carried out in fp32. Compiles to an empty kernel unless
// HALF_PRECISION_AVAILABLE is defined.
//
// Each thread handles one float2-sized chunk (four __half values) per loop
// iteration. The bias is indexed by the column only.
//
// row_stride : row length measured in float2-sized chunks (4 halves each)
// iterations : number of blockDim.x-wide passes over a row
__global__ void fused_bias_gelu(const __half* input,
                                const __half* bias,
                                __half* vals,
                                int row_stride,
                                int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;

    const float2* src = reinterpret_cast<const float2*>(input);
    float2* dst = reinterpret_cast<float2*>(vals);
    const float2* bias_src = reinterpret_cast<const float2*>(bias);

    for (int i = 0; i < iterations; i++) {
        const int col = i * step + tid;
        if (col >= row_stride) { continue; }

        const int offset = row * row_stride + col;

        float2 packed = src[offset];
        float2 packed_bias = bias_src[col];

        __half2* pairs = reinterpret_cast<__half2*>(&packed);
        __half2* bias_pairs = reinterpret_cast<__half2*>(&packed_bias);

        float2 lo = __half22float2(pairs[0]);
        float2 hi = __half22float2(pairs[1]);

        const float2 lo_b = __half22float2(bias_pairs[0]);
        const float2 hi_b = __half22float2(bias_pairs[1]);

        lo.x += lo_b.x;
        lo.y += lo_b.y;
        hi.x += hi_b.x;
        hi.y += hi_b.y;

        lo.x = gelu(lo.x);
        lo.y = gelu(lo.y);
        hi.x = gelu(hi.x);
        hi.y = gelu(hi.y);

        pairs[0] = __float22half2_rn(lo);
        pairs[1] = __float22half2_rn(hi);

        dst[offset] = packed;
    }
#endif
}
// GELU backward, fp32, in place on d_output:
//   d_output *= d_gelu(gelu_input + bias)
// Each block handles the row given by blockIdx.x; the bias is indexed by the
// column only.
//
// d_output / gelu_input / bias : accessed as float4 arrays
// row_stride                   : row length measured in float4 elements
// iterations                   : number of blockDim.x-wide passes over a row
__global__ void d_gelu_func(float* d_output,
                            const float* gelu_input,
                            const float* bias,
                            int row_stride,
                            int iterations)
{
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;

    float4* grad = reinterpret_cast<float4*>(d_output);
    const float4* pre_act = reinterpret_cast<const float4*>(gelu_input);
    const float4* bias_vec = reinterpret_cast<const float4*>(bias);

    for (int i = 0; i < iterations; i++) {
        const int col = i * step + tid;
        if (col >= row_stride) { continue; }

        const int offset = row * row_stride + col;

        float4 g = grad[offset];
        float4 x = pre_act[offset];
        const float4 b = bias_vec[col];

        // Rebuild the value that was fed to GELU in the forward pass.
        x.x += b.x;
        x.y += b.y;
        x.z += b.z;
        x.w += b.w;

        g.x *= d_gelu(x.x);
        g.y *= d_gelu(x.y);
        g.z *= d_gelu(x.z);
        g.w *= d_gelu(x.w);

        grad[offset] = g;
    }
}
// GELU backward, fp16, in place on d_output:
//   d_output *= d_gelu(gelu_input + bias)
// with the arithmetic carried out in fp32. Compiles to an empty kernel unless
// HALF_PRECISION_AVAILABLE is defined.
//
// Each thread handles one float2-sized chunk (four __half values) per loop
// iteration. The bias is indexed by the column only.
//
// row_stride : row length measured in float2-sized chunks (4 halves each)
// iterations : number of blockDim.x-wide passes over a row
__global__ void d_gelu_func(__half* d_output,
                            const __half* gelu_input,
                            const __half* bias,
                            int row_stride,
                            int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int step = blockDim.x;

    float2* grad = reinterpret_cast<float2*>(d_output);
    const float2* pre_act = reinterpret_cast<const float2*>(gelu_input);
    const float2* bias_src = reinterpret_cast<const float2*>(bias);

#pragma unroll
    for (int i = 0; i < iterations; i++) {
        const int col = i * step + tid;
        if (col >= row_stride) { continue; }

        const int offset = row * row_stride + col;

        float2 packed_grad = grad[offset];
        float2 packed_x = pre_act[offset];
        float2 packed_bias = bias_src[col];

        __half2* grad_pairs = reinterpret_cast<__half2*>(&packed_grad);
        __half2* x_pairs = reinterpret_cast<__half2*>(&packed_x);
        __half2* bias_pairs = reinterpret_cast<__half2*>(&packed_bias);

        float2 g_lo = __half22float2(grad_pairs[0]);
        float2 g_hi = __half22float2(grad_pairs[1]);

        float2 x_lo = __half22float2(x_pairs[0]);
        float2 x_hi = __half22float2(x_pairs[1]);

        const float2 b_lo = __half22float2(bias_pairs[0]);
        const float2 b_hi = __half22float2(bias_pairs[1]);

        // Rebuild the value that was fed to GELU in the forward pass.
        x_lo.x += b_lo.x;
        x_lo.y += b_lo.y;
        x_hi.x += b_hi.x;
        x_hi.y += b_hi.y;

        g_lo.x *= d_gelu(x_lo.x);
        g_lo.y *= d_gelu(x_lo.y);
        g_hi.x *= d_gelu(x_hi.x);
        g_hi.y *= d_gelu(x_hi.y);

        // Pack the four results back into one 8-byte store.
        float2 packed_out;
        __half2* out_pairs = reinterpret_cast<__half2*>(&packed_out);
        out_pairs[0] = __float22half2_rn(g_lo);
        out_pairs[1] = __float22half2_rn(g_hi);

        grad[offset] = packed_out;
    }
#endif
}
// Launches the fused_bias_gelu kernel on `stream`.
//
// input, bias, output: device buffers forwarded unchanged to the kernel.
// intermediate_size:   row length in elements of T. The kernel receives
//                      intermediate_size / 4 as its row stride, so any
//                      remainder modulo 4 is not covered.
//                      NOTE(review): presumably callers guarantee a multiple
//                      of 4 - confirm.
// batch_size:          number of rows; one block is launched per row.
// stream:              HIP stream the kernel is enqueued on.
//
// Launch shape: iterations = ceil(intermediate_size / 1024) passes per thread,
// and threads = ceil(intermediate_size / (4 * iterations)) threads per block,
// which never exceeds 256.
template <typename T>
void launch_bias_gelu(const T* input,
                      const T* bias,
                      T* output,
                      int intermediate_size,
                      int batch_size,
                      hipStream_t stream)
{
    // An empty problem needs no launch. Returning here also prevents the
    // division by `iterations == 0` below and a launch with an empty grid.
    if (intermediate_size <= 0 || batch_size <= 0) { return; }

    const int iterations = (intermediate_size + 1023) / 1024;
    const int threads = (intermediate_size - 1) / (iterations * 4) + 1;

    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    hipLaunchKernelGGL((fused_bias_gelu),
                       dim3(grid_dims),
                       dim3(block_dims),
                       0,
                       stream,
                       input,
                       bias,
                       output,
                       intermediate_size / 4,
                       iterations);
}
// Launches the gelu_kernel kernel on `stream`.
//
// input, output:     device buffers forwarded unchanged to the kernel.
// intermediate_size: row length in elements of T. The kernel receives
//                    intermediate_size / 4 as its row stride, so any
//                    remainder modulo 4 is not covered.
//                    NOTE(review): presumably callers guarantee a multiple
//                    of 4 - confirm.
// batch_size:        number of rows; one block is launched per row.
// stream:            HIP stream the kernel is enqueued on.
//
// Launch shape: iterations = ceil(intermediate_size / 1024) passes per thread,
// and threads = ceil(intermediate_size / (4 * iterations)) threads per block,
// which never exceeds 256.
template <typename T>
void launch_gelu(const T* input,
                 T* output,
                 int intermediate_size,
                 int batch_size,
                 hipStream_t stream)
{
    // An empty problem needs no launch. Returning here also prevents the
    // division by `iterations == 0` below and a launch with an empty grid.
    if (intermediate_size <= 0 || batch_size <= 0) { return; }

    const int iterations = (intermediate_size + 1023) / 1024;
    const int threads = (intermediate_size - 1) / (iterations * 4) + 1;

    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    hipLaunchKernelGGL((gelu_kernel),
                       dim3(grid_dims),
                       dim3(block_dims),
                       0,
                       stream,
                       input,
                       output,
                       intermediate_size / 4,
                       iterations);
}
// Explicit instantiations of the forward launchers for the two element types
// used here (float and __half), so their definitions are emitted in this
// translation unit.
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(const __half*,
const __half*,
__half*,
int,
int,
hipStream_t);
template void launch_gelu<float>(const float*, float*, int, int, hipStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, hipStream_t);
// Launches the d_gelu_func kernel on `stream`; the kernel updates d_output in
// place.
//
// d_output:          device buffer that is read and overwritten by the kernel.
// input, bias:       device buffers forwarded unchanged to the kernel as its
//                    gelu_input and bias arguments.
// intermediate_size: row length in elements of T. The kernel receives
//                    intermediate_size / 4 as its row stride, so any
//                    remainder modulo 4 is not covered.
//                    NOTE(review): presumably callers guarantee a multiple
//                    of 4 - confirm.
// batch_size:        number of rows; one block is launched per row.
// stream:            HIP stream the kernel is enqueued on.
//
// Launch shape: iterations = ceil(intermediate_size / 1024) passes per thread,
// and threads = ceil(intermediate_size / (4 * iterations)) threads per block,
// which never exceeds 256.
template <typename T>
void launch_d_gelu(T* d_output,
                   const T* input,
                   const T* bias,
                   int intermediate_size,
                   int batch_size,
                   hipStream_t stream)
{
    // An empty problem needs no launch. Returning here also prevents the
    // division by `iterations == 0` below and a launch with an empty grid.
    if (intermediate_size <= 0 || batch_size <= 0) { return; }

    const int iterations = (intermediate_size + 1023) / 1024;
    const int threads = (intermediate_size - 1) / (iterations * 4) + 1;

    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    hipLaunchKernelGGL((d_gelu_func),
                       dim3(grid_dims),
                       dim3(block_dims),
                       0,
                       stream,
                       d_output,
                       input,
                       bias,
                       intermediate_size / 4,
                       iterations);
}
// Explicit instantiations of the backward launcher for the two element types
// used here (float and __half), so their definitions are emitted in this
// translation unit.
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, hipStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, hipStream_t);
Prev
1
…
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment