Commit 4acf0e01 (OpenDAS / deepspeed)
Authored Apr 26, 2023 by aiss
Commit message: delete hip file
Parent: 7dd68788
Changes: 83

Showing 20 changed files with 0 additions and 11428 deletions
csrc/transformer/inference/includes/context_hip.h                    +0 -185
csrc/transformer/inference/includes/cublas_wrappers_hip.h            +0 -414
csrc/transformer/inference/includes/custom_hip_layers.h              +0 -125
csrc/transformer/inference/includes/inference_context_hip.h          +0 -260
csrc/transformer/inference/includes/inference_cublas_wrappers_hip.h  +0 -418
csrc/transformer/inference/includes/inference_hip_layers.h           +0 -223
csrc/transformer/normalize_kernels.hip                               +0 -2123
csrc/transformer/softmax_kernels.hip                                 +0 -597
csrc/transformer/transform_kernels.hip                               +0 -577
csrc/transformer_bak/cublas_wrappers.cu                              +0 -403
csrc/transformer_bak/cublas_wrappers.hip                             +0 -404
csrc/transformer_bak/dropout_kernels.cu                              +0 -868
csrc/transformer_bak/dropout_kernels.hip                             +0 -870
csrc/transformer_bak/ds_transformer_cuda.cpp                         +0 -1051
csrc/transformer_bak/ds_transformer_hip.cpp                          +0 -1052
csrc/transformer_bak/gelu_kernels.cu                                 +0 -330
csrc/transformer_bak/gelu_kernels.hip                                +0 -332
csrc/transformer_bak/general_kernels.cu                              +0 -411
csrc/transformer_bak/general_kernels.hip                             +0 -413
csrc/transformer_bak/inference/csrc/apply_rotary_pos_emb.cu          +0 -372
csrc/transformer/inference/includes/context_hip.h  deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "rocblas.h"
#include "hip/hip_runtime.h"
#include "hiprand/hiprand.h"
#define WARP_SIZE 32
#define CUDA_CHECK(callstr) \
{ \
hipError_t error_code = callstr; \
if (error_code != hipSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
    return std::max(
        std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
        // Use at least 1 block, since CUDA does not allow empty block
        1);
}

class Context {
public:
    Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0)
    {
        hiprandCreateGenerator(&_gen, HIPRAND_RNG_PSEUDO_DEFAULT);
        hiprandSetPseudoRandomGeneratorSeed(_gen, 123);
        if (rocblas_create_handle(&_cublasHandle) != rocblas_status_success) {
            auto message = std::string("Fail to create cublas handle.");
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
#ifndef __HIP_PLATFORM_HCC__
        rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
        hipEventCreate(&_comp1_event, (hipEventDisableTiming | hipEventBlockingSync));
        hipEventCreate(&_comp2_event, (hipEventDisableTiming | hipEventBlockingSync));
        hipEventCreate(&_comp_event, (hipEventDisableTiming | hipEventBlockingSync));
        hipEventCreate(&_comm_event, (hipEventDisableTiming | hipEventBlockingSync));
#else
        hipEventCreate(&_comp1_event);
        hipEventCreate(&_comp2_event);
        hipEventCreate(&_comp_event);
        hipEventCreate(&_comm_event);
#endif
    }

    virtual ~Context()
    {
        rocblas_destroy_handle(_cublasHandle);
        hipFree(_workspace);
        hipEventDestroy(_comp1_event);
        hipEventDestroy(_comp2_event);
        hipEventDestroy(_comp_event);
        hipEventDestroy(_comm_event);
    }

    static Context& Instance()
    {
        static Context _ctx;
        return _ctx;
    }

    void GenWorkSpace(size_t size)
    {
        if (!_workspace) {
            assert(_workspace == nullptr);
            hipMalloc(&_workspace, size);
        } else if (_workSpaceSize < size) {
            hipFree(_workspace);
            hipMalloc(&_workspace, size);
        }
        _workSpaceSize = size;
    }

    hipEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }

    size_t get_workspace_size() const { return _workSpaceSize; }
    void* GetWorkSpace() { return _workspace; }

    inline unsigned new_token(unsigned layer_id)
    {
        if (layer_id == 0) _token_length++;
        return _token_length;
    }

    inline void reset_tokens(unsigned initial_tokens = 0)
    {
        _num_tokens = initial_tokens;
    }  //_token_length = 0; }

    inline unsigned current_tokens() const { return _num_tokens; }

    inline void advance_tokens() { _num_tokens++; }

    hiprandGenerator_t& GetRandGenerator() { return _gen; }

    hipStream_t GetCommStream(bool async_op = false)
    {
        if (!_comm_stream)
            _comm_stream = async_op ? at::hip::getStreamFromPoolMasqueradingAsCUDA(true)
                                    : at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
        return _comm_stream;
    }

    hipStream_t GetCurrentStream(bool other_stream = false)
    {
        // get current pytorch stream.
        if (other_stream) {
            if (!_stream) _stream = at::hip::getStreamFromPoolMasqueradingAsCUDA(true);
            return _stream;
        }
        hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
        return stream;
    }

    rocblas_handle GetCublasHandle() { return _cublasHandle; }

    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }

    inline void SynchComp()
    {
        hipEventRecord(_comp_event, _comp_stream);
        hipStreamWaitEvent(_comm_stream, _comp_event, 0);
    }

    inline void SynchComm()
    {
        hipEventRecord(_comm_event, _comm_stream);
        hipStreamWaitEvent(_comp_stream, _comm_event, 0);
    }

private:
    hiprandGenerator_t _gen;
    rocblas_handle _cublasHandle;
    hipEvent_t _comp_event;
    hipEvent_t _comm_event;
    void* _workspace;
    uint64_t _seed;
    uint64_t _curr_offset;
    size_t _workSpaceSize;
    hipEvent_t _comp1_event;
    hipEvent_t _comp2_event;
    hipStream_t _stream;
    unsigned _token_length;
    unsigned _num_tokens;
    std::vector<std::array<int, 3>> _gemm_algos;
    hipStream_t _comp_stream;
    hipStream_t _comm_stream;
    std::unordered_map<int, int> _world_sizes;
};
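
For orientation, here is a minimal usage sketch (not part of the deleted file or of this commit) showing how the Context singleton above was typically consumed: DS_GET_BLOCKS sizes the grid, CUDA_1D_KERNEL_LOOP strides over the data, and GetCurrentStream() picks up the active PyTorch HIP stream. The kernel and wrapper names are hypothetical.

// Hypothetical example, assuming the header above is included.
__global__ void scale_kernel(float* data, float factor, int n)
{
    CUDA_1D_KERNEL_LOOP(i, n) { data[i] *= factor; }
}

void scale_inplace(float* data, float factor, int n)
{
    // Launch on the current PyTorch HIP stream managed by the singleton.
    hipStream_t stream = Context::Instance().GetCurrentStream();
    hipLaunchKernelGGL(scale_kernel,
                       dim3(DS_GET_BLOCKS(n)),
                       dim3(DS_CUDA_NUM_THREADS),
                       0,
                       stream,
                       data,
                       factor,
                       n);
}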
csrc/transformer/inference/includes/cublas_wrappers_hip.h  deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <assert.h>
#include <rocblas.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#ifndef __HIP_PLATFORM_HCC__
#include <mma.h>
#endif
#include <stdio.h>
#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb,
                   int m, int n, int k, const float* alpha, const float* beta,
                   const float* A, const float* B, float* C, rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb,
                   int m, int n, int k, const float* alpha, const float* beta,
                   const float* A, const float* B, float* C, cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    rocblas_status status = rocblas_gemm_ex(handle, transa, transb, m, n, k,
                                            (const void*)alpha,
                                            (const void*)A, rocblas_datatype_f32_r,
                                            (transa == rocblas_operation_none) ? m : k,
                                            (const void*)B, rocblas_datatype_f32_r,
                                            (transb == rocblas_operation_none) ? k : n,
                                            (const void*)beta,
                                            C, rocblas_datatype_f32_r, m,
                                            C, rocblas_datatype_f32_r, m,
                                            rocblas_datatype_f32_r, algo, 0, 0);
#else
    rocblas_status status = rocblas_gemmex(handle, transa, transb, m, n, k,
                                           (const void*)alpha,
                                           (const void*)A, hipR32F,
                                           (transa == rocblas_operation_none) ? m : k,
                                           (const void*)B, hipR32F,
                                           (transb == rocblas_operation_none) ? k : n,
                                           (const void*)beta,
                                           C, hipR32F, m,
                                           hipR32F, algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
    if (status != rocblas_status_success) {
#else
    if (status != rocblas_status_success) {
#endif
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb,
                   int m, int n, int k, const float* alpha, const float* beta,
                   const __half* A, const __half* B, __half* C, rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb,
                   int m, int n, int k, const float* alpha, const float* beta,
                   const __half* A, const __half* B, __half* C, cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    rocblas_status status = rocblas_gemm_ex(handle, transa, transb, m, n, k,
                                            (const void*)alpha,
                                            (const void*)A, rocblas_datatype_f16_r,
                                            (transa == rocblas_operation_none) ? m : k,
                                            (const void*)B, rocblas_datatype_f16_r,
                                            (transb == rocblas_operation_none) ? k : n,
                                            (const void*)beta,
                                            (void*)C, rocblas_datatype_f16_r, m,
                                            (void*)C, rocblas_datatype_f16_r, m,
                                            rocblas_datatype_f32_r, algo, 0, 0);
#else
    rocblas_status status = rocblas_gemmex(handle, transa, transb, m, n, k,
                                           (const void*)alpha,
                                           (const void*)A, hipR16F,
                                           (transa == rocblas_operation_none) ? m : k,
                                           (const void*)B, hipR16F,
                                           (transb == rocblas_operation_none) ? k : n,
                                           (const void*)beta,
                                           (void*)C, hipR16F, m,
                                           hipR32F, algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
    if (status != rocblas_status_success) {
#else
    if (status != rocblas_status_success) {
#endif
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, int k,
                                const float* alpha, const float* beta,
                                const float* A, const float* B, float* C,
                                rocblas_operation op_A, rocblas_operation op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, int k,
                                const float* alpha, const float* beta,
                                const float* A, const float* B, float* C,
                                rocblas_operation op_A, rocblas_operation op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    rocblas_status status =
        rocblas_gemm_strided_batched_ex(handle, op_A, op_B, m, n, k,
                                        alpha,
                                        A, rocblas_datatype_f32_r,
                                        (op_A == rocblas_operation_none) ? m : k, stride_A,
                                        B, rocblas_datatype_f32_r,
                                        (op_B == rocblas_operation_none) ? k : n, stride_B,
                                        beta,
                                        C, rocblas_datatype_f32_r, m, stride_C,
                                        C, rocblas_datatype_f32_r, m, stride_C,
                                        batch, rocblas_datatype_f32_r, algo, 0, 0);
#else
    rocblas_status status =
        cublasGemmStridedBatchedEx(handle, op_A, op_B, m, n, k,
                                   alpha,
                                   A, hipR32F,
                                   (op_A == rocblas_operation_none) ? m : k, stride_A,
                                   B, hipR32F,
                                   (op_B == rocblas_operation_none) ? k : n, stride_B,
                                   beta,
                                   C, hipR32F, m, stride_C,
                                   batch, hipR32F, algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
    if (status != rocblas_status_success) {
#else
    if (status != rocblas_status_success) {
#endif
        fprintf(stderr,
                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
                batch, m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, int k,
                                const float* alpha, const float* beta,
                                const __half* A, const __half* B, __half* C,
                                rocblas_operation op_A, rocblas_operation op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, int k,
                                const float* alpha, const float* beta,
                                const __half* A, const __half* B, __half* C,
                                rocblas_operation op_A, rocblas_operation op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    rocblas_status status =
        rocblas_gemm_strided_batched_ex(handle, op_A, op_B, m, n, k,
                                        alpha,
                                        A, rocblas_datatype_f16_r,
                                        (op_A == rocblas_operation_none) ? m : k, stride_A,
                                        B, rocblas_datatype_f16_r,
                                        (op_B == rocblas_operation_none) ? k : n, stride_B,
                                        beta,
                                        C, rocblas_datatype_f16_r, m, stride_C,
                                        C, rocblas_datatype_f16_r, m, stride_C,
                                        batch, rocblas_datatype_f32_r, algo, 0, 0);
#else
    rocblas_status status =
        cublasGemmStridedBatchedEx(handle, op_A, op_B, m, n, k,
                                   alpha,
                                   A, hipR16F,
                                   (op_A == rocblas_operation_none) ? m : k, stride_A,
                                   B, hipR16F,
                                   (op_B == rocblas_operation_none) ? k : n, stride_B,
                                   beta,
                                   C, hipR16F, m, stride_C,
                                   batch, hipR32F, algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
    if (status != rocblas_status_success) {
#else
    if (status != rocblas_status_success) {
#endif
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}
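
As a point of reference, here is a hedged sketch (not taken from this commit) of how the wrappers above were typically driven from host code; the helper name and device buffers are assumptions. rocBLAS treats matrices as column-major, so with no transposes the leading dimensions reduce to m, k and m.

// Hypothetical call site: C (m x n) = A (m x k) * B (k x n), single precision,
// with column-major device buffers already allocated and filled by the caller.
int run_fp32_gemm(rocblas_handle handle,
                  const float* d_A, const float* d_B, float* d_C,
                  int m, int n, int k)
{
    float alpha = 1.0f;
    float beta = 0.0f;
    return cublas_gemm_ex(handle,
                          rocblas_operation_none,  // A not transposed
                          rocblas_operation_none,  // B not transposed
                          m, n, k,
                          &alpha, &beta,
                          d_A, d_B, d_C,
                          rocblas_gemm_algo_standard);
}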
csrc/transformer/inference/includes/custom_hip_layers.h  deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
#pragma once
#ifdef __HIP_PLATFORM_HCC__
#define HALF_PRECISION_AVAILABLE = 1
#include <hip/hip_cooperative_groups.h>
#else
#if __CUDA_ARCH__ >= 700
#define HALF_PRECISION_AVAILABLE = 1
#endif
#include <cooperative_groups.h>
#endif
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#include <cassert>
#include <iostream>
#define MAX_WARP_NUM 32
#define WARP_SIZE 32
#define SMs 80
#define MAX_REGISTERS 256
template <typename T>
void launch_attn_softmax_v2(T* vals, T* mask, bool triangular, bool recompute,
                            bool local_attention, int window_size, int batch_size, int heads,
                            int num_seq, int sequence_length, float scale, hipStream_t stream);

// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(T* input, const T* bias, int intermediate_size, int batch_size,
                      hipStream_t stream);

template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size,
                     hipStream_t stream);

template <typename T>
void launch_bias_residual(T* input, T* output, T* attn, T* bias, T* attn_bias,
                          int batch, int hidden_dim, int mp_size, hipStream_t stream);

template <typename T>
void launch_layer_norm(T* out, T* vals, const T* gamma, const T* beta, float epsilon,
                       int batch_size, int hidden_dim, hipStream_t stream);

template <typename T>
void launch_residual_layer_norm(T* norm, T* res_add, T* vals, T* residual, const T* bias,
                                const T* gamma, const T* beta, float epsilon, int batch_size,
                                int hidden_dim, bool preLN, bool mlp_after_attn,
                                hipStream_t stream);

template <typename T>
void launch_dequantize(T* output, const int8_t* input, const float* qscale,
                       unsigned output_size, unsigned hidden_dim, unsigned groups,
                       unsigned merge_count, hipStream_t stream);

template <typename T>
void launch_gptj_residual_add(T* input, T* output, T* attn, T* bias, T* attn_bias,
                              int batch, int head_size, int mp_size, hipStream_t stream);

template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query, T* key_layer, unsigned head_size,
                                 unsigned seq_len, unsigned rotary_dim, unsigned offset,
                                 unsigned num_heads, unsigned batch, bool rotate_half,
                                 bool rotate_every_two, hipStream_t stream);

template <typename T>
void launch_moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim,
                           hipStream_t stream);
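
A short, hypothetical host-side sketch of how one of the launchers declared above (launch_bias_gelu) would be invoked; the helper name and buffer layout are assumptions for illustration, not part of this commit.

// Hypothetical helper: fused bias + GeLU applied in place to an activation
// buffer laid out as [rows, intermediate_size] in half precision.
void bias_gelu_inplace(__half* activations, const __half* bias, int intermediate_size, int rows)
{
    hipStream_t stream = Context::Instance().GetCurrentStream();
    launch_bias_gelu<__half>(activations, bias, intermediate_size, rows, stream);
}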
csrc/transformer/inference/includes/inference_context_hip.h  deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#pragma once
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <hip/hip_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "rocblas.h"
#include "hip/hip_runtime.h"
#define MEGABYTE (1024 * 1024)
#define GIGABYTE (1024 * 1024 * 1024)
// TODO: refactor out
#define WARP_SIZE 32
#define CUDA_CHECK(callstr) \
{ \
hipError_t error_code = callstr; \
if (error_code != hipSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
    return std::max(
        std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
        // Use at least 1 block, since CUDA does not allow empty block
        1);
}

class Context {
public:
    Context()
        : _workspace(nullptr),
          _seed(42),
          _curr_offset(0),
          _stream(0),
          _free_memory_size(0),
          _num_tokens(1),
          _attention_unfused_workspace_offset(0)
    {
        if (rocblas_create_handle(&_cublasHandle) != rocblas_status_success) {
            auto message = std::string("Fail to create cublas handle.");
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
#ifndef __HIP_PLATFORM_HCC__
        rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
        hipEventCreate(&_comp1_event);
        hipEventCreate(&_comp2_event);
        hipEventCreate(&_comp_event);
        hipEventCreate(&_comm_event);
    }

    virtual ~Context()
    {
        rocblas_destroy_handle(_cublasHandle);
        hipFree(_workspace);
        hipEventDestroy(_comp1_event);
        hipEventDestroy(_comp2_event);
        hipEventDestroy(_comp_event);
        hipEventDestroy(_comm_event);
    }

    static Context& Instance()
    {
        static Context _ctx;
        return _ctx;
    }

    void GenWorkSpace(const unsigned& num_layers,
                      const unsigned& num_heads,
                      const size_t& batch_size,
                      const size_t& prompt_len,
                      const size_t& hidden_dim,
                      const unsigned& mp_size,
                      const bool& external_cache,
                      const size_t& elem_size,
                      const unsigned& rank,
                      unsigned max_out_tokens)
    {
        size_t total_size;
        if (!_free_memory_size) { hipMemGetInfo(&_free_memory_size, &total_size); }

        // Flash attention requires padded heads and we'll conservatively allocate
        // for that here. Flash attention is only enabled for head size <= 128 right now
        const int head_size = hidden_dim / num_heads;
        const int padded_head_size = head_size <= 32 ? 32 : (head_size <= 64 ? 64 : 128);
        const int effective_head_size = (head_size > 128) ? head_size : padded_head_size;

        size_t activation_size = 16 * (num_heads * effective_head_size) * batch_size;
        // Other sequence length dimension is added when the final workSpaceSize is calculated
        size_t temp_size = batch_size * num_heads * max_out_tokens * 2;
        size_t cache_size =
            num_layers * batch_size * ((num_heads * effective_head_size) / mp_size) * 2;
        size_t minimal_requirements =
            temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE;
        if (_free_memory_size < minimal_requirements) {
            printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n",
                   minimal_requirements,
                   _free_memory_size,
                   total_size);
            throw std::runtime_error("Workspace can't be allocated, no enough memory.");
        }

        _max_seq_len = ((_free_memory_size - minimal_requirements) / elem_size) /
                       (activation_size + temp_size + cache_size);
        _max_seq_len = std::min((size_t)max_out_tokens, _max_seq_len);
        size_t workSpaceSize = ((external_cache ? (activation_size + temp_size)
                                                : (activation_size + temp_size + cache_size))) *
                               _max_seq_len * elem_size;
        temp_size *= _max_seq_len * elem_size;

        if (rank == 0 && !_workspace)
            printf(
                "------------------------------------------------------\n"
                "Free memory : %f (GigaBytes)  \n"
                "Total memory: %f (GigaBytes)  \n"
                "Requested memory: %f (GigaBytes) \n"
                "Setting maximum total tokens (input + output) to %lu \n"
                "------------------------------------------------------\n",
                (float)_free_memory_size / GIGABYTE,
                (float)total_size / GIGABYTE,
                (float)workSpaceSize / GIGABYTE,
                _max_seq_len);

        if (!_workspace) {
            assert(_workspace == nullptr);
            hipMalloc(&_workspace, workSpaceSize);
        } else if (_workSpaceSize < workSpaceSize) {
            hipFree(_workspace);
            hipMalloc(&_workspace, workSpaceSize);
        }
        if (!_workspace) {
            printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n",
                   workSpaceSize,
                   _free_memory_size,
                   total_size);
            throw std::runtime_error("Workspace is null.");
        }
        _workSpaceSize = workSpaceSize;
        _attention_unfused_workspace_offset = workSpaceSize - temp_size;
    }

    inline size_t GetMaxTokenLenght() const { return _max_seq_len; }

    hipEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }

    size_t get_workspace_size() const { return _workSpaceSize; }
    void* GetWorkSpace() { return _workspace; }
    void* GetAttentionUnfusedWorkspace()
    {
        return (char*)_workspace + _attention_unfused_workspace_offset;
    }

    inline unsigned new_token(unsigned layer_id)
    {
        if (layer_id == 0) _token_length++;
        return _token_length;
    }

    inline void reset_tokens(unsigned initial_tokens = 1)
    {
        _num_tokens = initial_tokens;
    }  //_token_length = 0; }

    inline unsigned current_tokens() const { return _num_tokens; }

    inline void advance_tokens() { _num_tokens++; }

    hipStream_t GetCommStream(bool async_op = false)
    {
        if (!_comm_stream)
            _comm_stream = async_op ? at::hip::getStreamFromPoolMasqueradingAsCUDA(true)
                                    : at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
        return _comm_stream;
    }

    hipStream_t GetCurrentStream(bool other_stream = false)
    {
        // get current pytorch stream.
        if (other_stream) {
            if (!_stream) _stream = at::hip::getStreamFromPoolMasqueradingAsCUDA(true);
            return _stream;
        }
        hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
        return stream;
    }

    rocblas_handle GetCublasHandle() { return _cublasHandle; }

    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }

    inline void SynchComp()
    {
        hipEventRecord(_comp_event, _comp_stream);
        hipStreamWaitEvent(_comm_stream, _comp_event, 0);
    }

    inline void SynchComm()
    {
        hipEventRecord(_comm_event, _comm_stream);
        hipStreamWaitEvent(_comp_stream, _comm_event, 0);
    }

private:
    rocblas_handle _cublasHandle;
    hipEvent_t _comp_event;
    hipEvent_t _comm_event;
    void* _workspace;
    // offset from _workspace for attention unfused memory
    size_t _attention_unfused_workspace_offset;
    uint64_t _seed;
    uint64_t _curr_offset;
    size_t _workSpaceSize;
    size_t _free_memory_size;
    size_t _max_seq_len;
    hipEvent_t _comp1_event;
    hipEvent_t _comp2_event;
    hipStream_t _stream;
    unsigned _token_length;
    unsigned _num_tokens;
    std::vector<std::array<int, 3>> _gemm_algos;
    hipStream_t _comp_stream;
    hipStream_t _comm_stream;
    std::unordered_map<int, int> _world_sizes;
};
csrc/transformer/inference/includes/inference_cublas_wrappers_hip.h  deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#pragma once
#include <assert.h>
#include <rocblas.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#ifndef __HIP_PLATFORM_HCC__
#include <mma.h>
#endif
#include <stdio.h>
#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb,
                   int m, int n, int k, const float* alpha, const float* beta,
                   const float* A, const float* B, float* C, rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb,
                   int m, int n, int k, const float* alpha, const float* beta,
                   const float* A, const float* B, float* C, cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    rocblas_status status = rocblas_gemm_ex(handle, transa, transb, m, n, k,
                                            (const void*)alpha,
                                            (const void*)A, rocblas_datatype_f32_r,
                                            (transa == rocblas_operation_none) ? m : k,
                                            (const void*)B, rocblas_datatype_f32_r,
                                            (transb == rocblas_operation_none) ? k : n,
                                            (const void*)beta,
                                            C, rocblas_datatype_f32_r, m,
                                            C, rocblas_datatype_f32_r, m,
                                            rocblas_datatype_f32_r, algo, 0, 0);
#else
    rocblas_status status = rocblas_gemmex(handle, transa, transb, m, n, k,
                                           (const void*)alpha,
                                           (const void*)A, hipR32F,
                                           (transa == rocblas_operation_none) ? m : k,
                                           (const void*)B, hipR32F,
                                           (transb == rocblas_operation_none) ? k : n,
                                           (const void*)beta,
                                           C, hipR32F, m,
                                           hipR32F, algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
    if (status != rocblas_status_success) {
#else
    if (status != rocblas_status_success) {
#endif
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb,
                   int m, int n, int k, const float* alpha, const float* beta,
                   const __half* A, const __half* B, __half* C, rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb,
                   int m, int n, int k, const float* alpha, const float* beta,
                   const __half* A, const __half* B, __half* C, cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    rocblas_status status = rocblas_gemm_ex(handle, transa, transb, m, n, k,
                                            (const void*)alpha,
                                            (const void*)A, rocblas_datatype_f16_r,
                                            (transa == rocblas_operation_none) ? m : k,
                                            (const void*)B, rocblas_datatype_f16_r,
                                            (transb == rocblas_operation_none) ? k : n,
                                            (const void*)beta,
                                            (void*)C, rocblas_datatype_f16_r, m,
                                            (void*)C, rocblas_datatype_f16_r, m,
                                            rocblas_datatype_f32_r, algo, 0, 0);
#else
    rocblas_status status = rocblas_gemmex(handle, transa, transb, m, n, k,
                                           (const void*)alpha,
                                           (const void*)A, hipR16F,
                                           (transa == rocblas_operation_none) ? m : k,
                                           (const void*)B, hipR16F,
                                           (transb == rocblas_operation_none) ? k : n,
                                           (const void*)beta,
                                           (void*)C, hipR16F, m,
                                           hipR32F, algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
    if (status != rocblas_status_success) {
#else
    if (status != rocblas_status_success) {
#endif
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, int k,
                                const float* alpha, const float* beta,
                                const float* A, const float* B, float* C,
                                rocblas_operation op_A, rocblas_operation op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, int k,
                                const float* alpha, const float* beta,
                                const float* A, const float* B, float* C,
                                rocblas_operation op_A, rocblas_operation op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    rocblas_status status =
        rocblas_gemm_strided_batched_ex(handle, op_A, op_B, m, n, k,
                                        alpha,
                                        A, rocblas_datatype_f32_r,
                                        (op_A == rocblas_operation_none) ? m : k, stride_A,
                                        B, rocblas_datatype_f32_r,
                                        (op_B == rocblas_operation_none) ? k : n, stride_B,
                                        beta,
                                        C, rocblas_datatype_f32_r, m, stride_C,
                                        C, rocblas_datatype_f32_r, m, stride_C,
                                        batch, rocblas_datatype_f32_r, algo, 0, 0);
#else
    rocblas_status status =
        cublasGemmStridedBatchedEx(handle, op_A, op_B, m, n, k,
                                   alpha,
                                   A, hipR32F,
                                   (op_A == rocblas_operation_none) ? m : k, stride_A,
                                   B, hipR32F,
                                   (op_B == rocblas_operation_none) ? k : n, stride_B,
                                   beta,
                                   C, hipR32F, m, stride_C,
                                   batch, hipR32F, algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
    if (status != rocblas_status_success) {
#else
    if (status != rocblas_status_success) {
#endif
        fprintf(stderr,
                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
                batch, m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, int k,
                                const float* alpha, const float* beta,
                                const __half* A, const __half* B, __half* C,
                                rocblas_operation op_A, rocblas_operation op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(rocblas_handle handle, int m, int n, int k,
                                const float* alpha, const float* beta,
                                const __half* A, const __half* B, __half* C,
                                rocblas_operation op_A, rocblas_operation op_B,
                                int stride_A, int stride_B, int stride_C,
                                int batch, cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
    rocblas_status status =
        rocblas_gemm_strided_batched_ex(handle, op_A, op_B, m, n, k,
                                        alpha,
                                        A, rocblas_datatype_f16_r,
                                        (op_A == rocblas_operation_none) ? m : k, stride_A,
                                        B, rocblas_datatype_f16_r,
                                        (op_B == rocblas_operation_none) ? k : n, stride_B,
                                        beta,
                                        C, rocblas_datatype_f16_r, m, stride_C,
                                        C, rocblas_datatype_f16_r, m, stride_C,
                                        batch, rocblas_datatype_f32_r, algo, 0, 0);
#else
    rocblas_status status =
        cublasGemmStridedBatchedEx(handle, op_A, op_B, m, n, k,
                                   alpha,
                                   A, hipR16F,
                                   (op_A == rocblas_operation_none) ? m : k, stride_A,
                                   B, hipR16F,
                                   (op_B == rocblas_operation_none) ? k : n, stride_B,
                                   beta,
                                   C, hipR16F, m, stride_C,
                                   batch, hipR32F, algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
    if (status != rocblas_status_success) {
#else
    if (status != rocblas_status_success) {
#endif
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m, n, k, (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}
csrc/transformer/inference/includes/inference_hip_layers.h  deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#pragma once
#include "ds_kernel_utils_hip.h"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#include <cassert>
#include <iostream>
#define MAX_WARP_NUM 32
#define WARP_SIZE 32
#define MAX_THREADS 1024
#define SMs 80
#define MAX_REGISTERS 256
template <typename T>
void launch_attn_softmax_v2(T* vals, T* mask, T* alibi, float layer_scale, bool triangular,
                            bool recompute, bool local_attention, int window_size,
                            int batch_size, int heads, int num_seq, int sequence_length,
                            int offset, int mask_stride, int mp_size, hipStream_t stream);

// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(T* input, const T* bias, int intermediate_size, int batch_size,
                      hipStream_t stream);

template <typename T>
void launch_fused_bias_geglu(T* output, const T* activation, const T* bias, int rows,
                             int elems_per_row, hipStream_t stream);

// Fused bias add with relu activation
template <typename T>
void launch_bias_relu(T* input, const T* bias, int intermediate_size, int batch_size,
                      hipStream_t stream);

template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size,
                     hipStream_t stream);

template <typename T>
void launch_bias_residual(T* input, T* output, T* attn, T* bias, T* attn_bias, int batch,
                          int hidden_dim, int mp_size, bool preln, hipStream_t stream);

template <typename T>
void launch_fused_ln(T* output, const T* vals, const T* gamma, const T* beta, float epsilon,
                     int rows, int elems_per_row, hipStream_t stream);

template <typename T>
void launch_fused_residual_ln(T* output, const T* vals, const T* residual, const T* bias,
                              const T* gamma, const T* beta, float epsilon, int rows,
                              int elems_per_row, hipStream_t stream);

template <typename T>
void launch_fused_residual_ln_store_pre_ln_res(T* norm_output, T* res_output, const T* vals,
                                               const T* residual, const T* bias, const T* gamma,
                                               const T* beta, float epsilon, int rows,
                                               int elems_per_row, hipStream_t stream);

template <typename T>
void launch_dequantize(T* output, const int8_t* input, const float* qscale,
                       unsigned output_size, unsigned hidden_dim, unsigned groups,
                       unsigned merge_count, hipStream_t stream);

template <typename T>
void launch_dequantize(T* output, const int8_t* input, const float* qscale,
                       unsigned output_size, unsigned hidden_dim, unsigned groups,
                       hipStream_t stream);

template <typename T>
void launch_gptj_residual_add(T* input, T* output, T* attn, T* bias, T* attn_bias, int batch,
                              int head_size, int mp_size, hipStream_t stream);

template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query, T* key_layer, unsigned head_size,
                                 unsigned seq_len, unsigned rotary_dim, unsigned offset,
                                 unsigned num_heads, unsigned batch, bool rotate_half,
                                 bool rotate_every_two, hipStream_t stream, int max_out_tokens);

template <typename T>
void launch_moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim,
                           hipStream_t stream);

// 4D transform [0, 1, 2, 3] -> [0, 2, 1, 3]
template <typename T>
void launch_transform4d_0213(T* out, const T* in, int batch_size, int heads, int seq_length,
                             int hidden_dim, hipStream_t stream, int trans_count);

template <typename T>
void launch_bias_add_transform_0213(T* outputs, T* vals, T* vals1, const T* vals2,
                                    const T* bias, int batch_size, int seq_length,
                                    unsigned seq_offset, int seq_length1, int hidden_dim,
                                    int heads, int rotary_dim, bool rotate_half,
                                    bool rotate_every_two, hipStream_t stream, int trans_count,
                                    int max_out_tokens);

template <typename T>
void pad_data(T* padded_output, T* output, int bsz, int head_size, int padded_head_size,
              hipStream_t stream);

template <typename T>
void pad_head_seq(T* padded_output, T* output, int bsz, int seq_len, int padded_seq_len,
                  int head_size, int padded_head_size, hipStream_t stream);

template <typename T>
void launch_pad_add_transform_0213(T* output, const T* vals, int batch_size, int hidden_dim,
                                   int seq_length, int padded_seq_len, int heads,
                                   int padded_head_size, hipStream_t stream);
csrc/transformer/normalize_kernels.hip  deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
namespace cg = cooperative_groups;
/*
Fused bias add, residual (elementwise) add, and normalization layer.
For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for
__half2 instructions, and avoid the conversion overhead (1/8 of __half2 arithmetic).
For specific launch constraints, see the launch functions.
*/
#define NORM_REG (MAX_REGISTERS / 4)
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
float* means,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / WARP_SIZE;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if (high_index < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
if (training)
if (threadIdx.x == 0) means[row] = mean;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (threadIdx.x == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
__half* means,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> WARP_SIZE_BITS;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && threadIdx.x == 0) {
vars[row] = __float2half(variance);
means[row] = __float2half(mean);
}
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
T* vars,
T* means);
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
float* vars,
float* means)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
__half* vars,
__half* means)
{
int threads = 128;
dim3 grid_dim(batch_size);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
}
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / 32;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (threadIdx.x == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> WARP_SIZE_BITS;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
sum += g.shfl_down(sum, i);
}
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> WARP_SIZE_BITS)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> WARP_SIZE_BITS); i *= 2) {
variance += g.shfl_down(variance, i);
}
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && threadIdx.x == 0) vars[row] = __float2half(variance);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
T* vars);
/*
To tune this launch the following restrictions must be met:
For float:
row_stride == hidden_size
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
For half:
row_stride == hidden_size / 2
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
*/
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
float* vars)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
// There are some limitations to call below functions, now just enumerate the situations.
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
__half* vars)
{
int threads = 128;
dim3 grid_dim(batch_size);
// There are some limitations to call below functions, now just enumerate the situations.
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
}
/* Normalize Gamma & Betta gradients
* Compute gradients using either X_hat or
* normalize input (invertible).
* Combine transpose with gradients computation.
*/
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/* Normalize Gamma & Betta gradients
* Compute gradients using the input to
* the normalize.
* Combine transpose with gradients computation.
*/
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is invertible!
* We do the backward using the X_hat (X - u) / sqrt(variance) or the output of Normalization.
*/
__global__ void LayerNormBackward2(const float* out_grad,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] *
sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
//hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
// out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
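/* Illustrative note (not part of the original file): the __half kernel works on
 * __half2 pairs, so the launcher above passes hidden_dim / 2 as the row stride
 * and launches threads / 2 threads per block. As a worked example, assuming
 * THREADS == 256 (an assumption; the macro is defined in the shared header):
 * hidden_dim = 20480 falls in the (16384, 32768] tier, so threads <<= 2 gives
 * 1024, the block gets 512 threads, the row stride is 10240, and each thread
 * covers 10240 / 512 = 20 __half2 elements, which must fit in the NORM_REG
 * register arrays.
 */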
/* Backward Normalize (Input-Gradient)
 * Using the means and variances from the input.
 * This variant of the backward is not invertible:
 * it computes the gradient directly from the input (X).
 */
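/* Illustrative note (not part of the original file): this variant recomputes
 * x_hat_i = (x_i - mean) * rsqrtf(var) from the saved statistics instead of
 * inverting the output, then applies the same per-row identity as above:
 *
 *   dx_i = (gamma_i * dy_i
 *           - mean_j(gamma_j * dy_j)
 *           - x_hat_i * mean_j(gamma_j * dy_j * x_hat_j)) * rsqrtf(var)
 */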
__global__ void LayerNormBackward2(const float* out_grad,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> WARP_SIZE_BITS;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> WARP_SIZE_BITS;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 xu[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
__half mean_h = means[row];
__half2 mean_reg = __halves2half2(mean_h, mean_h);
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
xu[iterations] = (vals_hat_h[high_index] - mean_reg);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
// float2 result[iterations];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient), fused with the residual-gradient add
 * Using the means and variances from the input.
 * This variant of the backward is not invertible:
 * it computes the gradient directly from the input (X) and adds the second
 * incoming gradient (out_grad2) to the result.
 */
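/* Illustrative note (not part of the original file): the fused-add kernels
 * compute the same input gradient as the plain variants and add the second
 * incoming gradient stream in the same store, i.e. per element
 *
 *   inp_grad_i = dx_i(out_grad1) + out_grad2_i
 *
 * which folds the residual-connection gradient addition into the LayerNorm
 * backward pass.
 */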
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = X_vals[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] = X_vals[high_index];
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = iteration_stride >> WARP_SIZE_BITS;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
inp_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] = vals_hat_h[high_index];
iterations++;
}
__half mean_h = means[row];
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[NORM_REG];
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
        throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
csrc/transformer/softmax_kernels.hip
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <math.h>
#include "custom_hip_layers.h"
#include "general_kernels_hip.h"
namespace cg = cooperative_groups;
dim3 get_attn_softmax_grid(int batch_size, int heads, int sequence_length, int threads)
{
int seq_length4 = sequence_length / 4;
int block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
// Note that the Y and Z dimensions are limited to 65535, while X is basically unlimited:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
// The batch size is typically relatively small, while the sequence length could potentially be
// arbitrarily large. We therefore place the batch size second to avoid hitting the Y limit.
unsigned x = heads * sequence_length / block_compute_size;
unsigned y = batch_size;
return {x, y};
}
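/* Illustrative note (not part of the original file): get_attn_softmax_grid
 * packs several short rows into one block when a row (seq_length / 4 float4
 * elements) does not fill `threads`. For example, with threads = 128 and
 * sequence_length = 64, seq_length4 = 16 and block_compute_size = 8, so the
 * grid becomes { heads * 64 / 8, batch_size } and each block covers 8 rows.
 */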
// Fused attention + softmax
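/* Illustrative note (not part of the original file): each row below computes a
 * numerically stable masked softmax,
 *
 *   out_i = exp(x_i + mask_i - max_j(x_j + mask_j)) / (sum_j exp(x_j + mask_j - max) + 1e-6)
 *
 * where the row max and row sum come from the warp/block shuffle reductions;
 * out-of-range lanes are padded with -infinity so they contribute nothing.
 */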
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(float* vals,
const float* attn_mask,
int heads,
int seq_length,
int iterations)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> WARP_SIZE_BITS;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.y;
int row = blockIdx.x;
int max_threads_in_sequence = ::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.x * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> WARP_SIZE_BITS;
int lane = threadIdx.x & 0x1f;
float4* val_cast = reinterpret_cast<float4*>(vals);
const float4* attn_mask_cast = reinterpret_cast<const float4*>(attn_mask);
float4 data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float4 mask = attn_mask_cast[mask_offset + data_id];
data[i] = val_cast[data_offset + data_id];
data[i].x += mask.x;
data[i].y += mask.y;
data[i].z += mask.z;
data[i].w += mask.w;
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
data[i].x /= sum;
data[i].y /= sum;
data[i].z /= sum;
data[i].w /= sum;
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) val_cast[data_offset + data_id] = data[i];
}
}
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(__half* vals,
const __half* attn_mask,
int heads,
int seq_length,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> WARP_SIZE_BITS;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.y;
int row = blockIdx.x;
int max_threads_in_sequence = ::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.x * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> WARP_SIZE_BITS;
int lane = threadIdx.x & 0x1f;
float2* val_cast = reinterpret_cast<float2*>(vals);
const float2* attn_mask_cast = reinterpret_cast<const float2*>(attn_mask);
val_cast += data_offset;
attn_mask_cast += mask_offset;
float2 low_data[MAX_THREAD_ITERATIONS];
float2 high_data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 data = val_cast[data_id];
float2 mask = attn_mask_cast[data_id];
__half2* data_arr = reinterpret_cast<__half2*>(&data);
__half2* mask_arr = reinterpret_cast<__half2*>(&mask);
low_data[i] = __half22float2(data_arr[0]);
high_data[i] = __half22float2(data_arr[1]);
float2 low_mask = __half22float2(mask_arr[0]);
float2 high_mask = __half22float2(mask_arr[1]);
low_data[i].x += low_mask.x;
low_data[i].y += low_mask.y;
high_data[i].x += high_mask.x;
high_data[i].y += high_mask.y;
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
low_data[i].x /= sum;
low_data[i].y /= sum;
high_data[i].x /= sum;
high_data[i].y /= sum;
result_h[0] = __float22half2_rn(low_data[i]);
result_h[1] = __float22half2_rn(high_data[i]);
val_cast[data_id] = result_f;
}
}
#endif
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, hipStream_t);
template <>
void launch_attn_softmax<float>(float* vals,
const float* attn_mask,
int batch_size,
int heads,
int sequence_length,
hipStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
hipLaunchKernelGGL(( attn_softmax<2, (threads / 2), 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
hipLaunchKernelGGL(( attn_softmax<4, (threads / 4), 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
hipLaunchKernelGGL(( attn_softmax<8, (threads / 8), 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
hipLaunchKernelGGL(( attn_softmax<16, (threads / 16), 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 32), 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 64), 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 128), 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
hipLaunchKernelGGL(( attn_softmax<32, 1, 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupport Seq_Length! Check the restriction of the max_threads and "
"max_thread_iterations!");
}
}
template <>
void launch_attn_softmax<__half>(__half* vals,
const __half* attn_mask,
int batch_size,
int heads,
int sequence_length,
hipStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
hipLaunchKernelGGL(( attn_softmax<2, (threads / 2), 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
hipLaunchKernelGGL(( attn_softmax<4, (threads / 4), 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
hipLaunchKernelGGL(( attn_softmax<8, (threads / 8), 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
hipLaunchKernelGGL(( attn_softmax<16, (threads / 16), 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 32), 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 64), 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 128), 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
hipLaunchKernelGGL(( attn_softmax<32, 1, 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupport Seq_Length! Check the restriction of the max_threads and "
"max_thread_iterations!");
}
}
template <typename T, int tbSize, int blockStride>
__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> WARP_SIZE_BITS; // warp-count = num_threads / WARP_SIZE (32)
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride)
? (seq_length + iteration_stride - 1) / iteration_stride
: MAX_THREAD_ITERATIONS);
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> WARP_SIZE_BITS;
int lane = id & 0x1f;
T val_reg[MAX_THREAD_ITERATIONS];
T soft_reg[MAX_THREAD_ITERATIONS];
float grad_reg = 0.0f;
#pragma unroll
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
val_reg[i] = out_grad[row * block_width + data_id];
soft_reg[i] = soft_inp[row * block_width + data_id];
            grad_reg += ((float)val_reg[i] *
                         (float)soft_reg[i]); // doing this multiplication in half precision
                                              // can cost ~2% accuracy, so accumulate in float
}
}
for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = grad_reg;
b.sync();
if (lane < warp_num) grad_reg = partialSum[lane];
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
grad_reg = g.shfl(grad_reg, id / tbSize);
}
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg);
out_grad[row * block_width + data_id] = (T)temp;
}
}
}
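/* Illustrative note (not part of the original file): both backward kernels
 * implement the softmax Jacobian-vector product. With y = softmax(x) and dy the
 * incoming gradient, per row
 *
 *   dx_i = y_i * (dy_i - sum_j(dy_j * y_j))
 *
 * The dot product sum_j(dy_j * y_j) is the reduction above, accumulated in
 * float to avoid the half-precision accuracy loss noted in the comment.
 */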
template <typename T, int ITERATIONS>
__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/,
const T* output,
int softmax_length)
{
int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
int offset = batch_idx * softmax_length + threadIdx.x;
grad += offset;
output += offset;
T grad_reg[ITERATIONS];
T output_reg[ITERATIONS];
float sum = 0.0;
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length) {
grad_reg[i] = grad[i * WARP_SIZE];
output_reg[i] = output[i * WARP_SIZE];
sum += (float)grad_reg[i] * (float)output_reg[i];
}
}
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length)
grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum);
}
}
template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
const T* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream)
{
const int warps_per_block = 4;
dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
dim3 block_dim(WARP_SIZE, warps_per_block);
if (seq_length <= 32)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 1>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 64)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 128)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 256)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 384)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 12>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 512)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 768)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 24>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 1024)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 2048)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else
throw std::runtime_error(
std::string("Special sequence length found in softmax backward, seq_length: ") +
std::to_string(seq_length));
}
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
const __half* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream);
template void launch_attn_softmax_backward_v2<float>(float* out_grad,
const float* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream);
csrc/transformer/transform_kernels.hip
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
#define rows_trans 16
#define cols_trans 16
template <typename T>
__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width)
{
__shared__ T data_block[rows_trans * (cols_trans + 1)];
int r = threadIdx.x / cols_trans;
int c = threadIdx.x % cols_trans;
int m = row_width / cols_trans;
int i = blockIdx.x / m * rows_trans + r;
int j = blockIdx.x % m * cols_trans + c;
int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS);
for (int k = 0; k < rows_trans; k += row_stride)
data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j];
__syncthreads();
i = blockIdx.x % m * rows_trans + r;
j = blockIdx.x / m * cols_trans + c;
for (int k = 0; k < rows_trans; k += row_stride)
out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k];
}
template <>
void Transpose<__half>(const __half* inp_mat,
__half* out_mat,
int rows,
int cols,
hipStream_t stream)
{
int threads = THREADS;
hipLaunchKernelGGL(( Transpose_Kernel<__half>), dim3((rows * cols + threads - 1) / threads), dim3(threads), 0, stream,
inp_mat, out_mat, cols, rows);
}
template <>
void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols, hipStream_t stream)
{
int threads = THREADS;
hipLaunchKernelGGL(( Transpose_Kernel<float>), dim3((rows * cols + threads - 1) / threads), dim3(threads), 0, stream,
inp_mat, out_mat, cols, rows);
}
template <typename T>
__global__ void transform_0213(T* output,
const T* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext);
template <>
__global__ void transform_0213<float>(float* output,
const float* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / head_ext; // Sequence ID (0-127)
int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs;
}
template <>
__global__ void transform_0213<__half>(__half* output,
const __half* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
#ifdef HALF_PRECISION_AVAILABLE
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / head_ext; // Sequence ID (0-127)
int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr[1];
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0];
#endif
}
template <>
void launch_transform_0213<float>(float* output,
const float* vals,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, (seq_length * head_ext));
hipLaunchKernelGGL(( transform_0213<float>)
, dim3(grid_dim), dim3(block_dim), 0, stream, output, vals, hidden_dim, seq_length, heads, head_ext);
}
template <>
void launch_transform_0213<__half>(__half* output,
const __half* vals,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream)
{
hidden_dim >>= 3;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, (seq_length * head_ext));
hipLaunchKernelGGL(( transform_0213<__half>)
, dim3(grid_dim), dim3(block_dim), 0, stream, output, vals, hidden_dim, seq_length, heads, head_ext);
}
// Bias add
template <typename T>
__global__ void bias_add_transform_0213(T* output,
const T* vals,
const T* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext);
template <>
__global__ void bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride +
d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3];
float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];
float4 outputs;
outputs.x = inputs.x + biases.x;
outputs.y = inputs.y + biases.y;
outputs.z = inputs.z + biases.z;
outputs.w = inputs.w + biases.w;
output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride + d3] = outputs;
}
#define ATTN_H 3
#define MAX_SEQ_LINE 10
template <>
__global__ void bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
#ifdef HALF_PRECISION_AVAILABLE
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr;
float4 bias_arr;
float4 output_arr;
__half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(&output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
vals_vec += (cnt * d1_stride);
vals_vec += (d2 * d2_stride);
bias_vec += (cnt * d1_stride);
bias_vec += (d2 * d2_stride);
output_vec += (cnt * d0_stride * gridDim.x);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_stride);
output_vec += (d2 * d2_out_stride);
bias_arr = bias_vec[d3];
vals_arr = vals_vec[d3];
#if defined(__ACC_HALF__)
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
#else
float2 bias_arr_f[4];
float2 vals_arr_f[4];
#pragma unroll
for (int l = 0; l < 4; l++) {
bias_arr_f[l] = __half22float2(bias_half[l]);
vals_arr_f[l] = __half22float2(vals_half[l]);
vals_arr_f[l].x += bias_arr_f[l].x;
vals_arr_f[l].y += bias_arr_f[l].y;
output_half[l] = __float22half2_rn(vals_arr_f[l]);
}
#endif
output_vec[d3] = output_arr;
#endif
}
__global__ void bias_add_transform_0213_v2(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads)
{
#ifdef HALF_PRECISION_AVAILABLE
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8
int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = threadIdx.z; // blockIdx.z; // Hidden count
int d2 = threadIdx.y; // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr[1];
float4 bias_arr[1];
float4 output_arr[1];
__half2* vals_half = reinterpret_cast<__half2*>(vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
bias_arr[0] = bias_vec[iter_index];
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
vals_arr[0] = vals_vec[input_offset + iter_id];
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
in_data[iter_id] = output_arr[0];
}
__syncthreads();
iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_out_stride * gridDim.x);
int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);
int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = (iter * iteration_stride) + head_count;
int iter_offset =
(iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride;
output_vec[out_index + iter_offset] =
in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
}
#endif
}
// [B S C*H] - > C * [B A S N]
template <>
void launch_bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213<float>), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads, head_ext);
}
template <>
void launch_bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 3;
if (hidden_dim > 128 || hidden_dim < 16) {
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213<__half>), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads, head_ext);
} else {
dim3 block_dim(hidden_dim / heads, heads, trans_count);
dim3 grid_dim(batch_size, seq_length / 2);
hipLaunchKernelGGL(( bias_add_transform_0213_v2), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads);
}
}
template <typename T>
__global__ void transform4d_0213(T* out,
const T* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext);
template <>
__global__ void transform4d_0213<float>(float* out,
const float* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = d0_stride / heads;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = hidden_dim;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head
int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;
int cnt = blockIdx.z;
int d3 = threadIdx.x; // Values (groups of 8)
if (d2 < seq_length) {
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride +
d2 * d2_stride + d3];
out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride * gridDim.z + d3] = vals_vec;
}
}
template <>
__global__ void transform4d_0213<__half>(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
#ifdef HALF_PRECISION_AVAILABLE
int d0_stride = hidden_dim * (seq_length / head_ext);
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head
int d2 = blockIdx.z / head_ext; // Sequence
int cnt = blockIdx.y; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
in_vec += (cnt * d0_stride * gridDim.x);
in_vec += (d0 * d0_stride);
in_vec += (d2 * d2_stride);
in_vec += (d1 * d2_stride * seq_length);
out_vec += (cnt * d1_stride);
out_vec += (d1 * d2_stride);
out_vec += (d0 * d0_stride * gridDim.y);
out_vec += (d2 * d1_stride * gridDim.y);
out_vec[d3] = in_vec[d3];
#endif
}
__global__ void transform4d_0213_v2(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim)
{
#ifdef HALF_PRECISION_AVAILABLE
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y; // Head
int d2 = blockIdx.y; // Sequence
int cnt = threadIdx.z; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride;
int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
int iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_stride * gridDim.x);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = iter * iteration_stride + head_count;
int iter_offset = (iter_row % blockDim.y) * d2_stride;
in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] =
in_vec[input_offset + iter_offset * seq_length +
(iter_row / blockDim.y) * matrix_stride];
}
__syncthreads();
iteration_stride = d1_stride * blockDim.z;
int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
out_vec[output_offset + iter_id] = in_data[iter_id];
}
#endif
}
// 3 * [B A S N] - > [B S C*H]
template <>
void launch_transform4d_0213<float>(float* out,
const float* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 2;
dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count);
dim3 block_dims(hidden_dim / heads, 8);
hipLaunchKernelGGL(( transform4d_0213<float>)
, dim3(grid_dims), dim3(block_dims), 0, stream, out, in, heads, seq_length, hidden_dim, 1);
}
template <>
void launch_transform4d_0213<__half>(__half* out,
const __half* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 3;
if (hidden_dim > 128 || hidden_dim < 16) {
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext));
dim3 block_dims(hidden_dim / heads, (heads / head_ext));
hipLaunchKernelGGL(( transform4d_0213<__half>), dim3(grid_dims), dim3(block_dims), 0, stream,
out, in, heads, seq_length, hidden_dim, head_ext);
} else {
dim3 grid_dims(batch_size, seq_length / 2);
dim3 block_dims(hidden_dim / heads, heads, trans_count);
hipLaunchKernelGGL(( transform4d_0213_v2), dim3(grid_dims), dim3(block_dims), 0, stream,
out, in, heads, seq_length, hidden_dim);
}
}
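The two launchers above implement the layout changes noted in the kernel comments: [B S C*H] -> C * [B A S N] for the fused QKV output, and 3 * [B A S N] -> [B S C*H] (with trans_count = 1) on the way back. A rough usage sketch, assuming hypothetical device buffers; none of the names below are defined in this file:
// Sketch only; names and sizes are assumptions.
int batch = 8, seq = 128, hidden = 768, heads = 12;
float *workspace = nullptr, *qkv_out = nullptr, *qkv_bias = nullptr;  // assume: device allocations
float *context = nullptr, *attn_out = nullptr;                        // assume: device allocations
hipStream_t stream = 0;
// [batch, seq, 3*hidden] + bias  ->  3 contiguous [batch, heads, seq, hidden/heads] tensors
launch_bias_add_transform_0213<float>(workspace, qkv_out, qkv_bias, batch, seq, hidden, heads, stream, 3);
// ... batched attention GEMMs run on the head-major layout ...
// [batch, heads, seq, hidden/heads]  ->  [batch, seq, hidden] for the output projection
launch_transform4d_0213<float>(attn_out, context, batch, heads, seq, hidden, stream, 1);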
csrc/transformer_bak/cublas_wrappers.cu deleted 100644 → 0
#include "cublas_wrappers.h"
#ifdef __HIP_PLATFORM_HCC__
int
cublas_gemm_ex
(
rocblas_handle
handle
,
rocblas_operation
transa
,
rocblas_operation
transb
,
int
m
,
int
n
,
int
k
,
const
float
*
alpha
,
const
float
*
beta
,
const
float
*
A
,
const
float
*
B
,
float
*
C
,
rocblas_gemm_algo
algo
)
#else
int
cublas_gemm_ex
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
float
*
alpha
,
const
float
*
beta
,
const
float
*
A
,
const
float
*
B
,
float
*
C
,
cublasGemmAlgo_t
algo
)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status
status
=
rocblas_gemm_ex
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
(
const
void
*
)
alpha
,
(
const
void
*
)
A
,
rocblas_datatype_f32_r
,
(
transa
==
rocblas_operation_none
)
?
m
:
k
,
(
const
void
*
)
B
,
rocblas_datatype_f32_r
,
(
transb
==
rocblas_operation_none
)
?
k
:
n
,
(
const
void
*
)
beta
,
C
,
rocblas_datatype_f32_r
,
m
,
C
,
rocblas_datatype_f32_r
,
m
,
rocblas_datatype_f32_r
,
algo
,
0
,
0
);
#else
cublasStatus_t
status
=
cublasGemmEx
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
(
const
void
*
)
alpha
,
(
const
void
*
)
A
,
CUDA_R_32F
,
(
transa
==
CUBLAS_OP_N
)
?
m
:
k
,
(
const
void
*
)
B
,
CUDA_R_32F
,
(
transb
==
CUBLAS_OP_N
)
?
k
:
n
,
(
const
void
*
)
beta
,
C
,
CUDA_R_32F
,
m
,
CUDA_R_32F
,
algo
);
#endif
#ifdef __HIP_PLATFORM_HCC__
if
(
status
!=
rocblas_status_success
)
{
#else
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
#endif
fprintf
(
stderr
,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d)
\n
"
,
m
,
n
,
k
,
(
int
)
status
);
return
EXIT_FAILURE
;
}
return
0
;
}
#ifdef __HIP_PLATFORM_HCC__
int
cublas_gemm_ex
(
rocblas_handle
handle
,
rocblas_operation
transa
,
rocblas_operation
transb
,
int
m
,
int
n
,
int
k
,
const
float
*
alpha
,
const
float
*
beta
,
const
__half
*
A
,
const
__half
*
B
,
__half
*
C
,
rocblas_gemm_algo
algo
)
#else
int
cublas_gemm_ex
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
float
*
alpha
,
const
float
*
beta
,
const
__half
*
A
,
const
__half
*
B
,
__half
*
C
,
cublasGemmAlgo_t
algo
)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status
status
=
rocblas_gemm_ex
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
(
const
void
*
)
alpha
,
(
const
void
*
)
A
,
rocblas_datatype_f16_r
,
(
transa
==
rocblas_operation_none
)
?
m
:
k
,
(
const
void
*
)
B
,
rocblas_datatype_f16_r
,
(
transb
==
rocblas_operation_none
)
?
k
:
n
,
(
const
void
*
)
beta
,
(
void
*
)
C
,
rocblas_datatype_f16_r
,
m
,
(
void
*
)
C
,
rocblas_datatype_f16_r
,
m
,
rocblas_datatype_f32_r
,
algo
,
0
,
0
);
#else
cublasStatus_t
status
=
cublasGemmEx
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
(
const
void
*
)
alpha
,
(
const
void
*
)
A
,
CUDA_R_16F
,
(
transa
==
CUBLAS_OP_N
)
?
m
:
k
,
(
const
void
*
)
B
,
CUDA_R_16F
,
(
transb
==
CUBLAS_OP_N
)
?
k
:
n
,
(
const
void
*
)
beta
,
(
void
*
)
C
,
CUDA_R_16F
,
m
,
CUDA_R_32F
,
algo
);
#endif
#ifdef __HIP_PLATFORM_HCC__
if
(
status
!=
rocblas_status_success
)
{
#else
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
#endif
fprintf
(
stderr
,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d)
\n
"
,
m
,
n
,
k
,
(
int
)
status
);
return
EXIT_FAILURE
;
}
return
0
;
}
#ifdef __HIP_PLATFORM_HCC__
int
cublas_strided_batched_gemm
(
rocblas_handle
handle
,
int
m
,
int
n
,
int
k
,
const
float
*
alpha
,
const
float
*
beta
,
const
float
*
A
,
const
float
*
B
,
float
*
C
,
rocblas_operation
op_A
,
rocblas_operation
op_B
,
int
stride_A
,
int
stride_B
,
int
stride_C
,
int
batch
,
rocblas_gemm_algo
algo
)
#else
int
cublas_strided_batched_gemm
(
cublasHandle_t
handle
,
int
m
,
int
n
,
int
k
,
const
float
*
alpha
,
const
float
*
beta
,
const
float
*
A
,
const
float
*
B
,
float
*
C
,
cublasOperation_t
op_A
,
cublasOperation_t
op_B
,
int
stride_A
,
int
stride_B
,
int
stride_C
,
int
batch
,
cublasGemmAlgo_t
algo
)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status
status
=
rocblas_gemm_strided_batched_ex
(
handle
,
op_A
,
op_B
,
m
,
n
,
k
,
alpha
,
A
,
rocblas_datatype_f32_r
,
(
op_A
==
rocblas_operation_none
)
?
m
:
k
,
stride_A
,
B
,
rocblas_datatype_f32_r
,
(
op_B
==
rocblas_operation_none
)
?
k
:
n
,
stride_B
,
beta
,
C
,
rocblas_datatype_f32_r
,
m
,
stride_C
,
C
,
rocblas_datatype_f32_r
,
m
,
stride_C
,
batch
,
rocblas_datatype_f32_r
,
algo
,
0
,
0
);
#else
cublasStatus_t
status
=
cublasGemmStridedBatchedEx
(
handle
,
op_A
,
op_B
,
m
,
n
,
k
,
alpha
,
A
,
CUDA_R_32F
,
(
op_A
==
CUBLAS_OP_N
)
?
m
:
k
,
stride_A
,
B
,
CUDA_R_32F
,
(
op_B
==
CUBLAS_OP_N
)
?
k
:
n
,
stride_B
,
beta
,
C
,
CUDA_R_32F
,
m
,
stride_C
,
batch
,
CUDA_R_32F
,
algo
);
#endif
#ifdef __HIP_PLATFORM_HCC__
if
(
status
!=
rocblas_status_success
)
{
#else
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
#endif
fprintf
(
stderr
,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d)
\n
"
,
batch
,
m
,
n
,
k
,
(
int
)
status
);
return
EXIT_FAILURE
;
}
return
0
;
}
#ifdef __HIP_PLATFORM_HCC__
int
cublas_strided_batched_gemm
(
rocblas_handle
handle
,
int
m
,
int
n
,
int
k
,
const
float
*
alpha
,
const
float
*
beta
,
const
__half
*
A
,
const
__half
*
B
,
__half
*
C
,
rocblas_operation
op_A
,
rocblas_operation
op_B
,
int
stride_A
,
int
stride_B
,
int
stride_C
,
int
batch
,
rocblas_gemm_algo
algo
)
#else
int
cublas_strided_batched_gemm
(
cublasHandle_t
handle
,
int
m
,
int
n
,
int
k
,
const
float
*
alpha
,
const
float
*
beta
,
const
__half
*
A
,
const
__half
*
B
,
__half
*
C
,
cublasOperation_t
op_A
,
cublasOperation_t
op_B
,
int
stride_A
,
int
stride_B
,
int
stride_C
,
int
batch
,
cublasGemmAlgo_t
algo
)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status
status
=
rocblas_gemm_strided_batched_ex
(
handle
,
op_A
,
op_B
,
m
,
n
,
k
,
alpha
,
A
,
rocblas_datatype_f16_r
,
(
op_A
==
rocblas_operation_none
)
?
m
:
k
,
stride_A
,
B
,
rocblas_datatype_f16_r
,
(
op_B
==
rocblas_operation_none
)
?
k
:
n
,
stride_B
,
beta
,
C
,
rocblas_datatype_f16_r
,
m
,
stride_C
,
C
,
rocblas_datatype_f16_r
,
m
,
stride_C
,
batch
,
rocblas_datatype_f32_r
,
algo
,
0
,
0
);
#else
cublasStatus_t
status
=
cublasGemmStridedBatchedEx
(
handle
,
op_A
,
op_B
,
m
,
n
,
k
,
alpha
,
A
,
CUDA_R_16F
,
(
op_A
==
CUBLAS_OP_N
)
?
m
:
k
,
stride_A
,
B
,
CUDA_R_16F
,
(
op_B
==
CUBLAS_OP_N
)
?
k
:
n
,
stride_B
,
beta
,
C
,
CUDA_R_16F
,
m
,
stride_C
,
batch
,
CUDA_R_32F
,
algo
);
#endif
#ifdef __HIP_PLATFORM_HCC__
if
(
status
!=
rocblas_status_success
)
{
#else
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
#endif
fprintf
(
stderr
,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d)
\n
"
,
m
,
n
,
k
,
(
int
)
status
);
return
EXIT_FAILURE
;
}
return
0
;
}
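These wrappers assume column-major storage: the leading dimension of A is m (or k when transposed), of B is k (or n), and of C is always m, while alpha and beta are passed by pointer in FP32 even for the __half overloads. A hedged sketch of a single call computing C[m x n] = A[m x k] * B[k x n]; handle creation and the device buffers are assumed to exist elsewhere:
// Sketch only; the surrounding setup is an assumption, not part of this file.
int m = 1024, n = 1024, k = 1024;
float alpha = 1.0f, beta = 0.0f;
float *d_A = nullptr, *d_B = nullptr, *d_C = nullptr;  // assume: valid device buffers
#ifdef __HIP_PLATFORM_HCC__
cublas_gemm_ex(handle, rocblas_operation_none, rocblas_operation_none,
               m, n, k, &alpha, &beta, d_A, d_B, d_C, rocblas_gemm_algo_standard);
#else
cublas_gemm_ex(handle, CUBLAS_OP_N, CUBLAS_OP_N,
               m, n, k, &alpha, &beta, d_A, d_B, d_C, CUBLAS_GEMM_DEFAULT);
#endif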
csrc/transformer_bak/cublas_wrappers.hip deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
#include "cublas_wrappers_hip.h"
#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status status = rocblas_gemm_ex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
rocblas_datatype_f32_r,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
rocblas_datatype_f32_r,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
C,
rocblas_datatype_f32_r,
m,
C,
rocblas_datatype_f32_r,
m,
rocblas_datatype_f32_r,
algo,
0,
0);
#else
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR32F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR32F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
C,
hipR32F,
m,
hipR32F,
algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != rocblas_status_success) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status status = rocblas_gemm_ex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
rocblas_datatype_f16_r,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
rocblas_datatype_f16_r,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
(void*)C,
rocblas_datatype_f16_r,
m,
(void*)C,
rocblas_datatype_f16_r,
m,
rocblas_datatype_f32_r,
algo,
0,
0);
#else
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR16F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR16F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
(void*)C,
hipR16F,
m,
hipR32F,
algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != rocblas_status_success) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status status =
rocblas_gemm_strided_batched_ex(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
rocblas_datatype_f32_r,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
rocblas_datatype_f32_r,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
rocblas_datatype_f32_r,
m,
stride_C,
C,
rocblas_datatype_f32_r,
m,
stride_C,
batch,
rocblas_datatype_f32_r,
algo,
0,
0);
#else
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR32F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR32F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR32F,
m,
stride_C,
batch,
hipR32F,
algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != rocblas_status_success) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status status =
rocblas_gemm_strided_batched_ex(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
rocblas_datatype_f16_r,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
rocblas_datatype_f16_r,
m,
stride_C,
C,
rocblas_datatype_f16_r,
m,
stride_C,
batch,
rocblas_datatype_f32_r,
algo,
0,
0);
#else
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR16F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR16F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR16F,
m,
stride_C,
batch,
hipR32F,
algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != rocblas_status_success) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
csrc/transformer_bak/dropout_kernels.cu deleted 100644 → 0
#include "custom_cuda_layers.h"
const
int
unroll_factor
=
4
;
__global__
void
dropout_kernel
(
const
int
N
,
const
float
ratio
,
float
*
out
,
const
float
*
Xdata
,
uint8_t
*
mask
,
std
::
pair
<
uint64_t
,
uint64_t
>
seed
)
{
const
float
scale
=
1.
/
(
1.
-
ratio
);
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
.
first
,
idx
,
seed
.
second
,
&
state
);
CUDA_1D_KERNEL_LOOP
(
j
,
N
/
unroll_factor
)
{
float4
rand
=
curand_uniform4
(
&
state
);
uint8_t
m
[
unroll_factor
];
m
[
0
]
=
(
uint8_t
)(
rand
.
x
>
ratio
);
m
[
1
]
=
(
uint8_t
)(
rand
.
y
>
ratio
);
m
[
2
]
=
(
uint8_t
)(
rand
.
z
>
ratio
);
m
[
3
]
=
(
uint8_t
)(
rand
.
w
>
ratio
);
int
i
=
j
*
unroll_factor
;
mask
[
i
]
=
(
uint8_t
)
m
[
0
];
mask
[
i
+
1
]
=
(
uint8_t
)
m
[
1
];
mask
[
i
+
2
]
=
(
uint8_t
)
m
[
2
];
mask
[
i
+
3
]
=
(
uint8_t
)
m
[
3
];
out
[
i
]
=
Xdata
[
i
]
*
scale
*
m
[
0
];
out
[
i
+
1
]
=
Xdata
[
i
+
1
]
*
scale
*
m
[
1
];
out
[
i
+
2
]
=
Xdata
[
i
+
2
]
*
scale
*
m
[
2
];
out
[
i
+
3
]
=
Xdata
[
i
+
3
]
*
scale
*
m
[
3
];
}
int
high_index
=
((((
N
/
unroll_factor
)
-
1
)
/
blockDim
.
x
+
1
)
*
(
unroll_factor
*
blockDim
.
x
))
+
threadIdx
.
x
;
if
(
N
>
high_index
)
{
float4
rand
=
curand_uniform4
(
&
state
);
float
*
rand_data
=
&
(
rand
.
x
);
int
k
=
0
;
for
(
int
i
=
high_index
;
i
<
N
;
i
++
)
{
uint8_t
m
=
(
uint8_t
)(
rand_data
[
k
++
]
>
ratio
);
out
[
i
]
=
Xdata
[
i
]
*
scale
*
m
;
mask
[
i
]
=
m
;
}
}
}
__global__
void
dropout_kernel
(
const
int
N
,
const
float
ratio
,
__half
*
out
,
const
__half
*
Xdata
,
uint8_t
*
mask
,
std
::
pair
<
uint64_t
,
uint64_t
>
seed
)
{
const
float
scale
=
1.
/
(
1.
-
ratio
);
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
.
first
,
idx
,
seed
.
second
,
&
state
);
#ifdef __STOCHASTIC_MODE__
const
__half2
h_scale
=
__float2half2_rn
(
scale
);
const
float2
*
x_cast
=
reinterpret_cast
<
const
float2
*>
(
Xdata
);
float2
*
out_cast
=
reinterpret_cast
<
float2
*>
(
out
);
uint32_t
*
mask_cast
=
reinterpret_cast
<
uint32_t
*>
(
mask
);
uint32_t
m_32
;
uint8_t
*
m
=
reinterpret_cast
<
uint8_t
*>
(
&
m_32
);
float2
result_f
;
__half2
*
result_h
=
reinterpret_cast
<
__half2
*>
(
&
result_f
);
__half2
mask_h
[
2
];
float2
mask_f
[
2
];
CUDA_1D_KERNEL_LOOP
(
j
,
N
/
unroll_factor
)
{
float2
x_f
=
x_cast
[
j
];
__half2
*
x_h
=
reinterpret_cast
<
__half2
*>
(
&
x_f
);
float4
rand
=
curand_uniform4
(
&
state
);
m
[
0
]
=
(
uint8_t
)(
rand
.
x
>
ratio
);
m
[
1
]
=
(
uint8_t
)(
rand
.
y
>
ratio
);
m
[
2
]
=
(
uint8_t
)(
rand
.
z
>
ratio
);
m
[
3
]
=
(
uint8_t
)(
rand
.
w
>
ratio
);
float
*
mask_f_data
=
&
mask_f
[
0
].
x
;
#pragma unroll
for
(
int
i
=
0
;
i
<
unroll_factor
;
i
++
)
mask_f_data
[
i
]
=
(
float
)(
m
[
i
]);
mask_h
[
0
]
=
__float22half2_rn
(
mask_f
[
0
]);
mask_h
[
1
]
=
__float22half2_rn
(
mask_f
[
1
]);
result_h
[
0
]
=
x_h
[
0
]
*
h_scale
*
mask_h
[
0
];
result_h
[
1
]
=
x_h
[
1
]
*
h_scale
*
mask_h
[
1
];
out_cast
[
j
]
=
result_f
;
mask_cast
[
j
]
=
m_32
;
}
#else
CUDA_1D_KERNEL_LOOP
(
j
,
N
/
unroll_factor
)
{
int
i
=
j
*
unroll_factor
;
const
__half2
*
vals_half
=
reinterpret_cast
<
const
__half2
*>
(
Xdata
+
i
);
float2
vals_half_f
[
2
];
vals_half_f
[
0
]
=
__half22float2
(
vals_half
[
0
]);
vals_half_f
[
1
]
=
__half22float2
(
vals_half
[
1
]);
uint8_t
m
[
unroll_factor
];
float4
rand
=
curand_uniform4
(
&
state
);
m
[
0
]
=
(
uint8_t
)(
rand
.
x
>
ratio
);
m
[
1
]
=
(
uint8_t
)(
rand
.
y
>
ratio
);
m
[
2
]
=
(
uint8_t
)(
rand
.
z
>
ratio
);
m
[
3
]
=
(
uint8_t
)(
rand
.
w
>
ratio
);
out
[
i
]
=
__float2half
(
vals_half_f
[
0
].
x
*
scale
*
m
[
0
]);
out
[
i
+
1
]
=
__float2half
(
vals_half_f
[
0
].
y
*
scale
*
m
[
1
]);
out
[
i
+
2
]
=
__float2half
(
vals_half_f
[
1
].
x
*
scale
*
m
[
2
]);
out
[
i
+
3
]
=
__float2half
(
vals_half_f
[
1
].
y
*
scale
*
m
[
3
]);
mask
[
i
]
=
m
[
0
];
mask
[
i
+
1
]
=
m
[
1
];
mask
[
i
+
2
]
=
m
[
2
];
mask
[
i
+
3
]
=
m
[
3
];
}
#endif
int
high_index
=
((((
N
/
unroll_factor
)
-
1
)
/
blockDim
.
x
+
1
)
*
(
unroll_factor
*
blockDim
.
x
))
+
threadIdx
.
x
;
if
(
N
>
high_index
)
{
float4
rand
=
curand_uniform4
(
&
state
);
float
*
rand_data
=
&
(
rand
.
x
);
int
k
=
0
;
for
(
int
i
=
high_index
;
i
<
N
;
i
++
)
{
uint8_t
m
=
(
uint8_t
)(
rand_data
[
k
++
]
>
ratio
);
out
[
i
]
=
__float2half
((
float
)
Xdata
[
i
]
*
scale
*
m
);
mask
[
i
]
=
m
;
}
}
}
__global__
void
dropout_kernel_bwd
(
const
int
N
,
const
float
ratio
,
const
float
*
Xdata
,
float
*
out
,
uint8_t
*
mask
,
std
::
pair
<
uint64_t
,
uint64_t
>
seed
)
{
const
float
scale
=
1.
/
(
1.
-
ratio
);
CUDA_1D_KERNEL_LOOP
(
j
,
N
/
unroll_factor
)
{
int
i
=
j
*
unroll_factor
;
out
[
i
]
=
mask
[
i
]
?
Xdata
[
i
]
*
scale
:
0.0
;
out
[
i
+
1
]
=
mask
[
i
+
1
]
?
Xdata
[
i
+
1
]
*
scale
:
0.0
;
out
[
i
+
2
]
=
mask
[
i
+
2
]
?
Xdata
[
i
+
2
]
*
scale
:
0.0
;
out
[
i
+
3
]
=
mask
[
i
+
3
]
?
Xdata
[
i
+
3
]
*
scale
:
0.0
;
}
int
high_index
=
((((
N
/
unroll_factor
)
-
1
)
/
blockDim
.
x
+
1
)
*
(
unroll_factor
*
blockDim
.
x
))
+
threadIdx
.
x
;
if
(
N
>
high_index
)
{
for
(
int
i
=
high_index
;
i
<
N
;
i
++
)
{
out
[
i
]
=
mask
[
i
]
?
Xdata
[
i
]
*
scale
:
0.0
;
}
}
}
__global__
void
dropout_kernel_bwd
(
const
int
N
,
const
float
ratio
,
const
__half
*
Xdata
,
__half
*
out
,
uint8_t
*
mask
,
std
::
pair
<
uint64_t
,
uint64_t
>
seed
)
{
const
float
scale
=
1.
/
(
1.
-
ratio
);
#ifdef __STOCHASTIC_MODE__
const
__half2
h_scale
=
__float2half2_rn
(
scale
);
const
float2
*
x_cast
=
reinterpret_cast
<
const
float2
*>
(
Xdata
);
float2
*
out_cast
=
reinterpret_cast
<
float2
*>
(
out
);
uint32_t
*
mask_cast
=
reinterpret_cast
<
uint32_t
*>
(
mask
);
CUDA_1D_KERNEL_LOOP
(
j
,
N
/
unroll_factor
)
{
float2
x_f
=
x_cast
[
j
];
__half2
*
x_h
=
reinterpret_cast
<
__half2
*>
(
&
x_f
);
uint32_t
m_32
=
mask_cast
[
j
];
uint8_t
*
m
=
(
uint8_t
*
)
&
m_32
;
__half2
mask_h
[
2
];
float2
mask_f
[
2
];
float
*
mask_f_data
=
&
mask_f
[
0
].
x
;
#pragma unroll
for
(
int
i
=
0
;
i
<
unroll_factor
;
i
++
)
mask_f_data
[
i
]
=
(
float
)(
m
[
i
]);
#pragma unroll
for
(
int
i
=
0
;
i
<
2
;
i
++
)
mask_h
[
i
]
=
__float22half2_rn
(
mask_f
[
i
]);
float2
result_f
;
__half2
*
result_h
=
reinterpret_cast
<
__half2
*>
(
&
result_f
);
result_h
[
0
]
=
x_h
[
0
]
*
h_scale
*
mask_h
[
0
];
result_h
[
1
]
=
x_h
[
1
]
*
h_scale
*
mask_h
[
1
];
out_cast
[
j
]
=
result_f
;
}
#else
const
__half
h_scale
=
__float2half
(
scale
);
const
__half
h_zero
=
__float2half
(
0.0
);
CUDA_1D_KERNEL_LOOP
(
j
,
N
/
unroll_factor
)
{
int
i
=
j
*
unroll_factor
;
const
__half2
*
vals_half
=
reinterpret_cast
<
const
__half2
*>
(
Xdata
+
i
);
uint8_t
*
m
=
mask
+
i
;
float2
vals_half_f
[
2
];
vals_half_f
[
0
]
=
__half22float2
(
vals_half
[
0
]);
vals_half_f
[
1
]
=
__half22float2
(
vals_half
[
1
]);
out
[
i
]
=
__float2half
(
vals_half_f
[
0
].
x
*
scale
*
m
[
0
]);
out
[
i
+
1
]
=
__float2half
(
vals_half_f
[
0
].
y
*
scale
*
m
[
1
]);
out
[
i
+
2
]
=
__float2half
(
vals_half_f
[
1
].
x
*
scale
*
m
[
2
]);
out
[
i
+
3
]
=
__float2half
(
vals_half_f
[
1
].
y
*
scale
*
m
[
3
]);
}
#endif
int
high_index
=
((((
N
/
unroll_factor
)
-
1
)
/
blockDim
.
x
+
1
)
*
(
unroll_factor
*
blockDim
.
x
))
+
threadIdx
.
x
;
if
(
N
>
high_index
)
{
for
(
int
i
=
high_index
;
i
<
N
;
i
++
)
{
out
[
i
]
=
__float2half
((
float
)
Xdata
[
i
]
*
scale
*
mask
[
i
]);
}
}
}
template
<
typename
T
>
void
launch_dropout
(
T
*
out
,
const
T
*
vals
,
uint8_t
*
mask
,
int
total_count
,
int
dim
,
float
ratio
,
cudaStream_t
stream
,
bool
bwd
)
{
assert
(
unroll_factor
==
4
);
dim3
grid_dim
=
DS_GET_BLOCKS
(
total_count
/
unroll_factor
);
dim3
block_dim
=
DS_CUDA_NUM_THREADS
;
if
(
dim
>
512
)
{
block_dim
.
x
>>=
1
;
grid_dim
.
x
<<=
1
;
}
uint64_t
inc
=
total_count
/
grid_dim
.
x
/
block_dim
.
x
;
std
::
pair
<
uint64_t
,
uint64_t
>
seed
=
Context
::
Instance
().
IncrementOffset
(
inc
);
if
(
bwd
)
dropout_kernel_bwd
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
total_count
,
ratio
,
vals
,
out
,
mask
,
seed
);
else
dropout_kernel
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
total_count
,
ratio
,
out
,
vals
,
mask
,
seed
);
}
template
void
launch_dropout
(
float
*
out
,
const
float
*
vals
,
uint8_t
*
mask
,
int
total_count
,
int
dim
,
float
ratio
,
cudaStream_t
stream
,
bool
);
template
void
launch_dropout
(
__half
*
out
,
const
__half
*
vals
,
uint8_t
*
mask
,
int
total_count
,
int
dim
,
float
ratio
,
cudaStream_t
stream
,
bool
);
__global__
void
dropout_grad_kernel
(
const
int
N
,
const
float
scale
,
float
*
Xdata
,
uint8_t
*
mask
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
N
)
{
Xdata
[
i
]
*=
scale
*
mask
[
i
];
}
}
__global__
void
dropout_grad_kernel
(
const
int
N
,
const
float
scale
,
__half
*
Xdata
,
uint8_t
*
mask
)
{
const
__half2
h_scale
=
__float2half2_rn
(
scale
);
float2
*
x_cast
=
reinterpret_cast
<
float2
*>
(
Xdata
);
uint32_t
*
mask_cast
=
reinterpret_cast
<
uint32_t
*>
(
mask
);
CUDA_1D_KERNEL_LOOP
(
j
,
N
/
unroll_factor
)
{
float2
x_data
=
x_cast
[
j
];
uint32_t
m_32
=
mask_cast
[
j
];
uint8_t
*
m
=
(
uint8_t
*
)
&
m_32
;
float2
result_f
;
__half2
*
result_h
=
reinterpret_cast
<
__half2
*>
(
&
result_f
);
#ifdef __STOCHASTIC_MODE__
__half2
*
x_data_h
=
reinterpret_cast
<
__half2
*>
(
&
x_data
);
__half2
mask_h
[
2
];
float2
mask_f
[
2
];
float
*
mask_f_data
=
&
mask_f
[
0
].
x
;
#pragma unroll
for
(
int
i
=
0
;
i
<
unroll_factor
;
i
++
)
*
(
mask_f_data
++
)
=
(
float
)(
m
[
i
]);
mask_h
[
0
]
=
__float22half2_rn
(
mask_f
[
0
]);
mask_h
[
1
]
=
__float22half2_rn
(
mask_f
[
1
]);
result_h
[
0
]
=
x_data_h
[
0
]
*
h_scale
*
mask_h
[
0
];
result_h
[
1
]
=
x_data_h
[
1
]
*
h_scale
*
mask_h
[
1
];
#else
__half
*
x_data_h
=
reinterpret_cast
<
__half
*>
(
&
x_data
);
float2
result
[
2
];
result
[
0
].
x
=
(
float
)
x_data_h
[
0
]
*
scale
*
m
[
0
];
result
[
0
].
y
=
(
float
)
x_data_h
[
1
]
*
scale
*
m
[
1
];
result
[
1
].
x
=
(
float
)
x_data_h
[
2
]
*
scale
*
m
[
2
];
result
[
1
].
y
=
(
float
)
x_data_h
[
3
]
*
scale
*
m
[
3
];
result_h
[
0
]
=
__float22half2_rn
(
result
[
0
]);
result_h
[
1
]
=
__float22half2_rn
(
result
[
1
]);
#endif
x_cast
[
j
]
=
result_f
;
}
int
high_index
=
((((
N
/
unroll_factor
)
-
1
)
/
blockDim
.
x
+
1
)
*
(
unroll_factor
*
blockDim
.
x
))
+
threadIdx
.
x
;
if
(
N
>
high_index
)
{
for
(
int
i
=
high_index
;
i
<
N
;
i
++
)
{
Xdata
[
i
]
=
__float2half
((
float
)
Xdata
[
i
]
*
scale
*
mask
[
i
]);
}
}
}
template
<
typename
T
>
void
launch_dropout_grad
(
T
*
vals
,
uint8_t
*
mask
,
int
total_count
,
float
ratio
,
cudaStream_t
stream
)
{
assert
(
unroll_factor
==
4
);
const
float
scale
=
1.
/
(
1.
-
ratio
);
dropout_grad_kernel
<<<
DS_GET_BLOCKS
(
total_count
/
unroll_factor
),
DS_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
total_count
,
scale
,
vals
,
mask
);
}
template
void
launch_dropout_grad
(
float
*
vals
,
uint8_t
*
mask
,
int
total_count
,
float
ratio
,
cudaStream_t
stream
);
template
void
launch_dropout_grad
(
__half
*
vals
,
uint8_t
*
mask
,
int
total_count
,
float
ratio
,
cudaStream_t
stream
);
__global__
void
dropout_grad_kernel
(
const
int
N
,
const
float
scale
,
const
float
*
Xdata
,
float
*
out
,
uint8_t
*
mask
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
N
)
{
out
[
i
]
=
Xdata
[
i
]
*
scale
*
mask
[
i
];
}
}
__global__
void
dropout_grad_kernel
(
const
int
N
,
const
float
scale
,
const
__half
*
Xdata
,
__half
*
out
,
uint8_t
*
mask
)
{
const
float2
*
x_cast
=
reinterpret_cast
<
const
float2
*>
(
Xdata
);
float2
*
out_cast
=
reinterpret_cast
<
float2
*>
(
out
);
const
uint32_t
*
mask_cast
=
reinterpret_cast
<
const
uint32_t
*>
(
mask
);
float2
result_f
;
__half2
*
result_h
=
reinterpret_cast
<
__half2
*>
(
&
result_f
);
CUDA_1D_KERNEL_LOOP
(
j
,
N
/
unroll_factor
)
{
float2
x_data
=
x_cast
[
j
];
uint32_t
m_32
=
mask_cast
[
j
];
uint8_t
*
m
=
(
uint8_t
*
)
&
m_32
;
__half
*
x_data_h
=
reinterpret_cast
<
__half
*>
(
&
x_data
);
float2
result
[
2
];
result
[
0
].
x
=
(
float
)
x_data_h
[
0
]
*
scale
*
m
[
0
];
result
[
0
].
y
=
(
float
)
x_data_h
[
1
]
*
scale
*
m
[
1
];
result
[
1
].
x
=
(
float
)
x_data_h
[
2
]
*
scale
*
m
[
2
];
result
[
1
].
y
=
(
float
)
x_data_h
[
3
]
*
scale
*
m
[
3
];
result_h
[
0
]
=
__float22half2_rn
(
result
[
0
]);
result_h
[
1
]
=
__float22half2_rn
(
result
[
1
]);
out_cast
[
j
]
=
result_f
;
}
int
high_index
=
((((
N
/
unroll_factor
)
-
1
)
/
blockDim
.
x
+
1
)
*
(
unroll_factor
*
blockDim
.
x
))
+
threadIdx
.
x
;
if
(
N
>
high_index
)
{
for
(
int
i
=
high_index
;
i
<
N
;
i
++
)
{
out
[
i
]
=
__float2half
((
float
)
Xdata
[
i
]
*
scale
*
mask
[
i
]);
}
}
}
template
<
typename
T
>
void
launch_dropout_grad
(
T
*
vals_out
,
const
T
*
vals
,
uint8_t
*
mask
,
int
total_count
,
float
ratio
,
cudaStream_t
stream
)
{
assert
(
unroll_factor
==
4
);
const
float
scale
=
1.
/
(
1.
-
ratio
);
dropout_grad_kernel
<<<
DS_GET_BLOCKS
(
total_count
/
unroll_factor
),
DS_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
total_count
,
scale
,
vals
,
vals_out
,
mask
);
}
template
void
launch_dropout_grad
(
float
*
,
const
float
*
vals
,
uint8_t
*
mask
,
int
total_count
,
float
ratio
,
cudaStream_t
stream
);
template
void
launch_dropout_grad
(
__half
*
,
const
__half
*
vals
,
uint8_t
*
mask
,
int
total_count
,
float
ratio
,
cudaStream_t
stream
);
__global__
void
dropout_kernel
(
const
int
N
,
const
int
dim
,
const
float
ratio
,
const
float
*
bias
,
float
*
Xdata
,
uint8_t
*
mask
,
std
::
pair
<
uint64_t
,
uint64_t
>
seed
)
{
const
float
scale
=
1.
/
(
1.
-
ratio
);
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
%
(
dim
/
unroll_factor
);
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
.
first
,
idx
,
seed
.
second
,
&
state
);
float4
*
Xdata_cast
=
reinterpret_cast
<
float4
*>
(
Xdata
);
uint32_t
*
mask_32
=
reinterpret_cast
<
uint32_t
*>
(
mask
);
const
float4
*
bias_cast
=
reinterpret_cast
<
const
float4
*>
(
bias
);
CUDA_1D_KERNEL_LOOP
(
j
,
N
)
{
float4
rand
=
curand_uniform4
(
&
state
);
uint32_t
m_32
;
uint8_t
*
m
=
(
uint8_t
*
)
&
m_32
;
m
[
0
]
=
(
uint8_t
)(
rand
.
x
>
ratio
);
m
[
1
]
=
(
uint8_t
)(
rand
.
y
>
ratio
);
m
[
2
]
=
(
uint8_t
)(
rand
.
z
>
ratio
);
m
[
3
]
=
(
uint8_t
)(
rand
.
w
>
ratio
);
float4
x_data
=
Xdata_cast
[
j
];
float4
b_data
=
bias_cast
[
j
%
(
dim
/
unroll_factor
)];
x_data
.
x
+=
b_data
.
x
;
x_data
.
y
+=
b_data
.
y
;
x_data
.
z
+=
b_data
.
z
;
x_data
.
w
+=
b_data
.
w
;
x_data
.
x
=
x_data
.
x
*
scale
*
m
[
0
];
x_data
.
y
=
x_data
.
y
*
scale
*
m
[
1
];
x_data
.
z
=
x_data
.
z
*
scale
*
m
[
2
];
x_data
.
w
=
x_data
.
w
*
scale
*
m
[
3
];
mask_32
[
j
]
=
m_32
;
Xdata_cast
[
j
]
=
x_data
;
}
int
high_index
=
((((
N
/
unroll_factor
)
-
1
)
/
blockDim
.
x
+
1
)
*
(
unroll_factor
*
blockDim
.
x
))
+
threadIdx
.
x
;
if
(
N
>
high_index
)
{
float4
rand
=
curand_uniform4
(
&
state
);
float
*
rand_data
=
&
(
rand
.
x
);
int
k
=
0
;
for
(
int
i
=
high_index
;
i
<
N
;
i
++
)
{
float
x_data
=
Xdata
[
i
]
+
bias
[
i
%
dim
];
uint8_t
m
=
(
uint8_t
)(
rand_data
[
k
++
]
>
ratio
);
Xdata
[
i
]
=
x_data
*
scale
*
m
;
mask
[
i
]
=
m
;
}
}
}
__global__
void
dropout_kernel
(
const
int
N
,
const
int
dim
,
const
float
ratio
,
const
__half
*
bias
,
__half
*
Xdata
,
uint8_t
*
mask
,
std
::
pair
<
uint64_t
,
uint64_t
>
seed
)
{
const
float
scale
=
1.
/
(
1.
-
ratio
);
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
%
(
dim
/
unroll_factor
);
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
.
first
,
idx
,
seed
.
second
,
&
state
);
float2
*
Xdata_cast
=
reinterpret_cast
<
float2
*>
(
Xdata
);
uint32_t
*
mask_32
=
reinterpret_cast
<
uint32_t
*>
(
mask
);
const
float2
*
bias_cast
=
reinterpret_cast
<
const
float2
*>
(
bias
);
CUDA_1D_KERNEL_LOOP
(
j
,
N
)
{
float4
rand
=
curand_uniform4
(
&
state
);
float2
data_f
;
__half2
*
data_h
=
reinterpret_cast
<
__half2
*>
(
&
data_f
);
float2
bias_f
;
__half2
*
bias_h
=
reinterpret_cast
<
__half2
*>
(
&
bias_f
);
data_f
=
Xdata_cast
[
j
];
bias_f
=
bias_cast
[
j
%
(
dim
/
unroll_factor
)];
float2
data_h_0
=
__half22float2
(
data_h
[
0
]);
float2
data_h_1
=
__half22float2
(
data_h
[
1
]);
float2
bias_h_0
=
__half22float2
(
bias_h
[
0
]);
float2
bias_h_1
=
__half22float2
(
bias_h
[
1
]);
data_h_0
.
x
+=
bias_h_0
.
x
;
data_h_0
.
y
+=
bias_h_0
.
y
;
data_h_1
.
x
+=
bias_h_1
.
x
;
data_h_1
.
y
+=
bias_h_1
.
y
;
uint32_t
m_32
;
uint8_t
*
m
=
(
uint8_t
*
)
&
m_32
;
m
[
0
]
=
(
uint8_t
)(
rand
.
x
>
ratio
);
m
[
1
]
=
(
uint8_t
)(
rand
.
y
>
ratio
);
m
[
2
]
=
(
uint8_t
)(
rand
.
z
>
ratio
);
m
[
3
]
=
(
uint8_t
)(
rand
.
w
>
ratio
);
data_h_0
.
x
=
__float2half
(
data_h_0
.
x
*
scale
*
m
[
0
]);
data_h_0
.
y
=
__float2half
(
data_h_0
.
y
*
scale
*
m
[
1
]);
data_h_1
.
x
=
__float2half
(
data_h_1
.
x
*
scale
*
m
[
2
]);
data_h_1
.
y
=
__float2half
(
data_h_1
.
y
*
scale
*
m
[
3
]);
float2
result_f
;
__half2
*
result_h
=
reinterpret_cast
<
__half2
*>
(
&
result_f
);
result_h
[
0
]
=
__float22half2_rn
(
data_h_0
);
result_h
[
1
]
=
__float22half2_rn
(
data_h_1
);
Xdata_cast
[
j
]
=
result_f
;
mask_32
[
j
]
=
m_32
;
}
int
high_index
=
((((
N
/
unroll_factor
)
-
1
)
/
blockDim
.
x
+
1
)
*
(
unroll_factor
*
blockDim
.
x
))
+
threadIdx
.
x
;
if
(
N
>
high_index
)
{
float4
rand
=
curand_uniform4
(
&
state
);
float
*
rand_data
=
&
(
rand
.
x
);
int
k
=
0
;
for
(
int
i
=
high_index
;
i
<
N
;
i
++
)
{
float
x_data
=
(
float
)
Xdata
[
i
]
+
(
float
)
bias
[
i
%
dim
];
uint8_t
m
=
(
uint8_t
)(
rand_data
[
k
++
]
>
ratio
);
Xdata
[
i
]
=
__float2half
(
x_data
*
scale
*
m
);
mask
[
i
]
=
m
;
}
}
}
template
<
typename
T
>
void
launch_dropout
(
T
*
out
,
const
T
*
bias
,
uint8_t
*
mask
,
int
batch
,
int
dim
,
float
ratio
,
cudaStream_t
stream
)
{
assert
(
unroll_factor
==
4
);
int
total_count
=
batch
*
dim
/
unroll_factor
;
dim3
grid_dim
=
DS_GET_BLOCKS
(
total_count
);
dim3
block_dim
=
DS_CUDA_NUM_THREADS
;
uint64_t
inc
=
(
batch
*
dim
)
/
grid_dim
.
x
/
block_dim
.
x
;
std
::
pair
<
uint64_t
,
uint64_t
>
seed
=
Context
::
Instance
().
IncrementOffset
(
inc
);
dropout_kernel
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
total_count
,
dim
,
ratio
,
bias
,
out
,
mask
,
seed
);
}
template
void
launch_dropout
(
float
*
,
const
float
*
bias
,
uint8_t
*
mask
,
int
batch
,
int
dim
,
float
ratio
,
cudaStream_t
stream
);
template
void
launch_dropout
(
__half
*
,
const
__half
*
bias
,
uint8_t
*
mask
,
int
batch
,
int
dim
,
float
ratio
,
cudaStream_t
stream
);
__global__
void
dropout_kernel
(
const
int
N
,
const
int
dim
,
const
float
ratio
,
const
float
*
input
,
const
float
*
residual
,
const
float
*
bias
,
float
*
out
,
uint8_t
*
mask
,
std
::
pair
<
uint64_t
,
uint64_t
>
seed
)
{
const
float
scale
=
1.
/
(
1.
-
ratio
);
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
%
(
dim
/
unroll_factor
);
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
.
first
,
idx
,
seed
.
second
,
&
state
);
float4
*
out_cast
=
reinterpret_cast
<
float4
*>
(
out
);
uint32_t
*
mask_32
=
reinterpret_cast
<
uint32_t
*>
(
mask
);
const
float4
*
bias_cast
=
reinterpret_cast
<
const
float4
*>
(
bias
);
const
float4
*
residual_cast
=
reinterpret_cast
<
const
float4
*>
(
residual
);
const
float4
*
input_cast
=
reinterpret_cast
<
const
float4
*>
(
input
);
CUDA_1D_KERNEL_LOOP
(
j
,
N
)
{
float4
rand
=
curand_uniform4
(
&
state
);
uint32_t
m_32
;
uint8_t
*
m
=
(
uint8_t
*
)
&
m_32
;
m
[
0
]
=
(
uint8_t
)(
rand
.
x
>
ratio
);
m
[
1
]
=
(
uint8_t
)(
rand
.
y
>
ratio
);
m
[
2
]
=
(
uint8_t
)(
rand
.
z
>
ratio
);
m
[
3
]
=
(
uint8_t
)(
rand
.
w
>
ratio
);
float4
out_data
;
float4
b_data
=
bias_cast
[
j
%
(
dim
/
unroll_factor
)];
float4
res_data
=
residual_cast
[
j
];
float4
inp_data
=
input_cast
[
j
];
out_data
.
x
=
(
b_data
.
x
+
inp_data
.
x
);
out_data
.
y
=
(
b_data
.
y
+
inp_data
.
y
);
out_data
.
z
=
(
b_data
.
z
+
inp_data
.
z
);
out_data
.
w
=
(
b_data
.
w
+
inp_data
.
w
);
out_data
.
x
=
out_data
.
x
*
scale
*
m
[
0
];
out_data
.
y
=
out_data
.
y
*
scale
*
m
[
1
];
out_data
.
z
=
out_data
.
z
*
scale
*
m
[
2
];
out_data
.
w
=
out_data
.
w
*
scale
*
m
[
3
];
out_data
.
x
+=
res_data
.
x
;
out_data
.
y
+=
res_data
.
y
;
out_data
.
z
+=
res_data
.
z
;
out_data
.
w
+=
res_data
.
w
;
mask_32
[
j
]
=
m_32
;
out_cast
[
j
]
=
out_data
;
}
int
high_index
=
((((
N
/
unroll_factor
)
-
1
)
/
blockDim
.
x
+
1
)
*
(
unroll_factor
*
blockDim
.
x
))
+
threadIdx
.
x
;
if
(
N
>
high_index
)
{
float4
rand
=
curand_uniform4
(
&
state
);
float
*
rand_data
=
&
(
rand
.
x
);
int
k
=
0
;
for
(
int
i
=
high_index
;
i
<
N
;
i
++
)
{
float
x_data
=
input
[
i
]
+
bias
[
i
%
dim
];
uint8_t
m
=
(
uint8_t
)(
rand_data
[
k
++
]
>
ratio
);
x_data
=
x_data
*
scale
*
m
;
x_data
+=
residual
[
i
];
out
[
i
]
=
x_data
;
mask
[
i
]
=
m
;
}
}
}
__global__
void
dropout_kernel
(
const
int
N
,
const
int
dim
,
const
float
ratio
,
const
__half
*
input
,
const
__half
*
residual
,
const
__half
*
bias
,
__half
*
out
,
uint8_t
*
mask
,
std
::
pair
<
uint64_t
,
uint64_t
>
seed
)
{
const
float
scale
=
1.
/
(
1.
-
ratio
);
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
%
(
dim
/
unroll_factor
);
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
.
first
,
idx
,
seed
.
second
,
&
state
);
float2
*
out_cast
=
reinterpret_cast
<
float2
*>
(
out
);
uint32_t
*
mask_32
=
reinterpret_cast
<
uint32_t
*>
(
mask
);
const
float2
*
bias_cast
=
reinterpret_cast
<
const
float2
*>
(
bias
);
const
float2
*
residual_cast
=
reinterpret_cast
<
const
float2
*>
(
residual
);
const
float2
*
input_cast
=
reinterpret_cast
<
const
float2
*>
(
input
);
CUDA_1D_KERNEL_LOOP
(
j
,
N
)
{
float4
rand
=
curand_uniform4
(
&
state
);
float2
data_f
;
__half2
*
data_h
=
reinterpret_cast
<
__half2
*>
(
&
data_f
);
float2
bias_f
;
__half2
*
bias_h
=
reinterpret_cast
<
__half2
*>
(
&
bias_f
);
float2
residual_f
;
__half2
*
residual_h
=
reinterpret_cast
<
__half2
*>
(
&
residual_f
);
float2
input_f
;
__half2
*
input_h
=
reinterpret_cast
<
__half2
*>
(
&
input_f
);
bias_f
=
bias_cast
[
j
%
(
dim
/
unroll_factor
)];
residual_f
=
residual_cast
[
j
];
input_f
=
input_cast
[
j
];
float2
data_h_0
=
__half22float2
(
data_h
[
0
]);
float2
data_h_1
=
__half22float2
(
data_h
[
1
]);
float2
bias_h_0
=
__half22float2
(
bias_h
[
0
]);
float2
bias_h_1
=
__half22float2
(
bias_h
[
1
]);
float2
residual_h_0
=
__half22float2
(
residual_h
[
0
]);
float2
residual_h_1
=
__half22float2
(
residual_h
[
1
]);
float2
input_h_0
=
__half22float2
(
input_h
[
0
]);
float2
input_h_1
=
__half22float2
(
input_h
[
1
]);
data_h_0
.
x
=
(
bias_h_0
.
x
+
input_h_0
.
x
);
data_h_0
.
y
=
(
bias_h_0
.
y
+
input_h_0
.
y
);
data_h_1
.
x
=
(
bias_h_1
.
x
+
input_h_1
.
x
);
data_h_1
.
y
=
(
bias_h_1
.
y
+
input_h_1
.
y
);
uint32_t
m_32
;
uint8_t
*
m
=
(
uint8_t
*
)
&
m_32
;
m
[
0
]
=
(
uint8_t
)(
rand
.
x
>
ratio
);
m
[
1
]
=
(
uint8_t
)(
rand
.
y
>
ratio
);
m
[
2
]
=
(
uint8_t
)(
rand
.
z
>
ratio
);
m
[
3
]
=
(
uint8_t
)(
rand
.
w
>
ratio
);
data_h_0
.
x
=
__float2half
(
data_h_0
.
x
*
scale
*
m
[
0
]);
data_h_0
.
y
=
__float2half
(
data_h_0
.
y
*
scale
*
m
[
1
]);
data_h_1
.
x
=
__float2half
(
data_h_1
.
x
*
scale
*
m
[
2
]);
data_h_1
.
y
=
__float2half
(
data_h_1
.
y
*
scale
*
m
[
3
]);
data_h_0
.
x
+=
residual_h_0
.
x
;
data_h_0
.
y
+=
residual_h_0
.
y
;
data_h_1
.
x
+=
residual_h_1
.
x
;
data_h_1
.
y
+=
residual_h_1
.
y
;
float2
result_f
;
__half2
*
result_h
=
reinterpret_cast
<
__half2
*>
(
&
result_f
);
result_h
[
0
]
=
__float22half2_rn
(
data_h_0
);
result_h
[
1
]
=
__float22half2_rn
(
data_h_1
);
out_cast
[
j
]
=
result_f
;
mask_32
[
j
]
=
m_32
;
}
int
high_index
=
((((
N
/
unroll_factor
)
-
1
)
/
blockDim
.
x
+
1
)
*
(
unroll_factor
*
blockDim
.
x
))
+
threadIdx
.
x
;
if
(
N
>
high_index
)
{
float4
rand
=
curand_uniform4
(
&
state
);
float
*
rand_data
=
&
(
rand
.
x
);
int
k
=
0
;
for
(
int
i
=
high_index
;
i
<
N
;
i
++
)
{
float
x_data
=
(
float
)
input
[
i
]
+
(
float
)
bias
[
i
%
dim
];
uint8_t
m
=
(
uint8_t
)(
rand_data
[
k
++
]
>
ratio
);
x_data
=
x_data
*
scale
*
m
;
x_data
+=
(
float
)
residual
[
i
];
out
[
i
]
=
__float2half
(
x_data
);
mask
[
i
]
=
m
;
}
}
}
template
<
typename
T
>
void
launch_dropout
(
T
*
out
,
const
T
*
input
,
const
T
*
residual
,
const
T
*
bias
,
uint8_t
*
mask
,
int
batch
,
int
dim
,
float
ratio
,
cudaStream_t
stream
)
{
assert
(
unroll_factor
==
4
);
int
total_count
=
batch
*
dim
/
unroll_factor
;
dim3
grid_dim
=
DS_GET_BLOCKS
(
total_count
);
dim3
block_dim
=
DS_CUDA_NUM_THREADS
;
uint64_t
inc
=
(
batch
*
dim
)
/
grid_dim
.
x
/
block_dim
.
x
;
std
::
pair
<
uint64_t
,
uint64_t
>
seed
=
Context
::
Instance
().
IncrementOffset
(
inc
);
dropout_kernel
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
total_count
,
dim
,
ratio
,
input
,
residual
,
bias
,
out
,
mask
,
seed
);
}
template
void
launch_dropout
(
float
*
,
const
float
*
,
const
float
*
residual
,
const
float
*
bias
,
uint8_t
*
mask
,
int
batch
,
int
dim
,
float
ratio
,
cudaStream_t
stream
);
template
void
launch_dropout
(
__half
*
,
const
__half
*
,
const
__half
*
residual
,
const
__half
*
bias
,
uint8_t
*
mask
,
int
batch
,
int
dim
,
float
ratio
,
cudaStream_t
stream
);
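The last overload above fuses three steps, out = residual + dropout(input + bias), and records the per-element byte mask so the backward pass can rescale gradients with the same pattern. A short sketch, assuming hypothetical buffer names and a 0.1 dropout ratio (none of these are defined in this file):
// Sketch only; names, sizes, and the ratio are assumptions.
int batch_size = 8, seq_length = 128, hidden_dim = 768;
int rows = batch_size * seq_length;  // one row per token
float *d_out = nullptr, *d_ffn_out = nullptr, *d_residual = nullptr, *d_bias = nullptr;
float *d_grad_out = nullptr, *d_grad_in = nullptr;  // assume: device buffers for backward
uint8_t* d_mask = nullptr;
cudaStream_t stream = 0;
launch_dropout<float>(d_out, d_ffn_out, d_residual, d_bias, d_mask, rows, hidden_dim, 0.1f, stream);
// backward: apply the saved mask (and the 1/(1-ratio) scale) to the incoming gradient
launch_dropout_grad<float>(d_grad_out, d_grad_in, d_mask, rows * hidden_dim, 0.1f, stream);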
csrc/transformer_bak/dropout_kernels.hip deleted 100644 → 0
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
const int unroll_factor = 4;
__global__ void dropout_kernel(const int N,
const float ratio,
float* out,
const float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float4 rand = hiprand_uniform4(&state);
uint8_t m[unroll_factor];
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
int i = j * unroll_factor;
mask[i] = (uint8_t)m[0];
mask[i + 1] = (uint8_t)m[1];
mask[i + 2] = (uint8_t)m[2];
mask[i + 3] = (uint8_t)m[3];
out[i] = Xdata[i] * scale * m[0];
out[i + 1] = Xdata[i + 1] * scale * m[1];
out[i + 2] = Xdata[i + 2] * scale * m[2];
out[i + 3] = Xdata[i + 3] * scale * m[3];
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = Xdata[i] * scale * m;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const float ratio,
__half* out,
const __half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
uint32_t m_32;
uint8_t* m = reinterpret_cast<uint8_t*>(&m_32);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
__half2 mask_h[2];
float2 mask_f[2];
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
float4 rand = hiprand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
mask_cast[j] = m_32;
}
#else
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
uint8_t m[unroll_factor];
float4 rand = hiprand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
mask[i] = m[0];
mask[i + 1] = m[1];
mask[i + 2] = m[2];
mask[i + 3] = m[3];
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = __float2half((float)Xdata[i] * scale * m);
mask[i] = m;
}
}
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const float* Xdata,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
out[i] = mask[i] ? Xdata[i] * scale : 0.0;
out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; }
}
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const __half* Xdata,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
#pragma unroll
for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
}
#else
const __half h_scale = __float2half(scale);
const __half h_zero = __float2half(0.0);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
uint8_t* m = mask + i;
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool bwd)
{
assert(unroll_factor == 4);
dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
dim3 block_dim = DS_CUDA_NUM_THREADS;
if (dim > 512) {
block_dim.x >>= 1;
grid_dim.x <<= 1;
}
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
if (bwd)
hipLaunchKernelGGL(( dropout_kernel_bwd), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, vals, out, mask, seed);
else
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, out, vals, mask, seed);
}
template void launch_dropout(float* out,
const float* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
template void launch_dropout(__half* out,
const __half* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
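The launchers above take an element count, the hidden dimension, the dropout ratio, a HIP stream, and a flag selecting the mask-generating or mask-replaying kernel. Below is a minimal host-side sketch of how they might be driven; the helper name and the device buffers d_in, d_out and d_mask are hypothetical and not part of this file, and it assumes total_count is a multiple of the unroll factor (4), as asserted in launch_dropout.
// Sketch only: generate a dropout mask on the forward pass, then replay the same
// mask with bwd=true, which rescales by 1/(1-ratio) wherever mask[i] != 0 instead
// of drawing new random numbers.
static void dropout_roundtrip(float* d_out, const float* d_in, uint8_t* d_mask,
                              int total_count, int dim, float ratio, hipStream_t stream)
{
    launch_dropout<float>(d_out, d_in, d_mask, total_count, dim, ratio, stream, /*bwd=*/false);
    launch_dropout<float>(d_out, d_out, d_mask, total_count, dim, ratio, stream, /*bwd=*/true);
}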
__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
const __half2 h_scale = __float2half2_rn(scale);
float2* x_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
#ifdef __STOCHASTIC_MODE__
__half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_data_h[0] * h_scale * mask_h[0];
result_h[1] = x_data_h[1] * h_scale * mask_h[1];
#else
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
#endif
x_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, hipStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
hipLaunchKernelGGL(( dropout_grad_kernel), dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
dim3(DS_CUDA_NUM_THREADS),
0,
stream, total_count, scale, vals, mask);
}
template void launch_dropout_grad(float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
__global__ void dropout_grad_kernel(const int N,
const float scale,
const float* Xdata,
float* out,
uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N,
const float scale,
const __half* Xdata,
__half* out,
uint8_t* mask)
{
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
out_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout_grad(T* vals_out,
const T* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
hipLaunchKernelGGL(( dropout_grad_kernel), dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
dim3(DS_CUDA_NUM_THREADS),
0,
stream, total_count, scale, vals, vals_out, mask);
}
template void launch_dropout_grad(float*,
const float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half*,
const __half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
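Both the in-place and out-of-place gradient kernels apply the same inverted-dropout scale 1/(1 - ratio) that the forward kernels use, so the expected value of an element is preserved. Writing p for ratio and m_i for the Bernoulli keep mask, the check is a one-liner:
\mathbb{E}[\mathrm{out}_i] = (1-p)\cdot\frac{x_i}{1-p} + p\cdot 0 = x_i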
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* bias,
float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 x_data = Xdata_cast[j];
float4 b_data = bias_cast[j % (dim / unroll_factor)];
x_data.x += b_data.x;
x_data.y += b_data.y;
x_data.z += b_data.z;
x_data.w += b_data.w;
x_data.x = x_data.x * scale * m[0];
x_data.y = x_data.y * scale * m[1];
x_data.z = x_data.z * scale * m[2];
x_data.w = x_data.w * scale * m[3];
mask_32[j] = m_32;
Xdata_cast[j] = x_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = Xdata[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = x_data * scale * m;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* bias,
__half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
data_f = Xdata_cast[j];
bias_f = bias_cast[j % (dim / unroll_factor)];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
data_h_0.x += bias_h_0.x;
data_h_0.y += bias_h_0.y;
data_h_1.x += bias_h_1.x;
data_h_1.y += bias_h_1.y;
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
Xdata_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)Xdata[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = __float2half(x_data * scale * m);
mask[i] = m;
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* input,
const float* residual,
const float* bias,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float4* out_cast = reinterpret_cast<float4*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
const float4* residual_cast = reinterpret_cast<const float4*>(residual);
const float4* input_cast = reinterpret_cast<const float4*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 out_data;
float4 b_data = bias_cast[j % (dim / unroll_factor)];
float4 res_data = residual_cast[j];
float4 inp_data = input_cast[j];
out_data.x = (b_data.x + inp_data.x);
out_data.y = (b_data.y + inp_data.y);
out_data.z = (b_data.z + inp_data.z);
out_data.w = (b_data.w + inp_data.w);
out_data.x = out_data.x * scale * m[0];
out_data.y = out_data.y * scale * m[1];
out_data.z = out_data.z * scale * m[2];
out_data.w = out_data.w * scale * m[3];
out_data.x += res_data.x;
out_data.y += res_data.y;
out_data.z += res_data.z;
out_data.w += res_data.w;
mask_32[j] = m_32;
out_cast[j] = out_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = input[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += residual[i];
out[i] = x_data;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* input,
const __half* residual,
const __half* bias,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
const float2* residual_cast = reinterpret_cast<const float2*>(residual);
const float2* input_cast = reinterpret_cast<const float2*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
float2 residual_f;
__half2* residual_h = reinterpret_cast<__half2*>(&residual_f);
float2 input_f;
__half2* input_h = reinterpret_cast<__half2*>(&input_f);
bias_f = bias_cast[j % (dim / unroll_factor)];
residual_f = residual_cast[j];
input_f = input_cast[j];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
float2 residual_h_0 = __half22float2(residual_h[0]);
float2 residual_h_1 = __half22float2(residual_h[1]);
float2 input_h_0 = __half22float2(input_h[0]);
float2 input_h_1 = __half22float2(input_h[1]);
data_h_0.x = (bias_h_0.x + input_h_0.x);
data_h_0.y = (bias_h_0.y + input_h_0.y);
data_h_1.x = (bias_h_1.x + input_h_1.x);
data_h_1.y = (bias_h_1.y + input_h_1.y);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
data_h_0.x += residual_h_0.x;
data_h_0.y += residual_h_0.y;
data_h_1.x += residual_h_1.x;
data_h_1.y += residual_h_1.y;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
out_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)input[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += (float)residual[i];
out[i] = __float2half(x_data);
mask[i] = m;
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* input,
const T* residual,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, input, residual, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float*,
const float* residual,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half*,
const __half* residual,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
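The two kernels above fuse bias addition, dropout and the residual connection into a single pass over the activations. A scalar CPU reference of the per-element arithmetic is given below as a sketch for checking outputs; the helper name and float-only signature are illustrative, not part of this file.
// Reference for one element: dropout is applied to (input + bias), and the residual
// is added after the mask and 1/(1-ratio) scaling, matching the kernels above.
static float bias_dropout_residual_ref(float input, float bias, float residual, uint8_t mask, float ratio)
{
    const float scale = 1.0f / (1.0f - ratio);
    return (input + bias) * scale * mask + residual;
}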
csrc/transformer_bak/ds_transformer_cuda.cpp
deleted
100644 → 0
View file @
7dd68788
#include <torch/extension.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer.h"
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
#include "ds_transformer_cuda.h"
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

const int init_seq_length = 128;

// C++ interface

template <typename T>
unsigned get_workspace_size(unsigned maxBatchSize,
                            unsigned seq_len,
                            unsigned hidden_size,
                            unsigned intermediate_size,
                            unsigned heads,
                            bool training,
                            bool gelu_checkpoint)
{
    unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
    if (training) {
        workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
        workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
                                     2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
        if (gelu_checkpoint)
            workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
    }
    return workSpacesize;  // * sizeof(T);
}
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
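get_workspace_size returns an element count rather than a byte count (hence the commented-out * sizeof(T) above), and the callers below allocate the workspace as a tensor of T with exactly that many elements. As an illustrative check with made-up sizes: for maxBatchSize = 8, seq_len = 128, hidden_size = 1024, intermediate_size = 4096, heads = 16, training enabled and gelu_checkpoint disabled, the count is 4·(8·128·1024) = 4,194,304, plus 2·(8·128·1024) = 2,097,152, plus max(8·128·4096, 2·8·16·128·128) = 4,194,304, for a total of 10,485,760 elements.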
template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
                                              unsigned batch_size,
                                              unsigned hidden_size,
                                              unsigned num_heads,
                                              unsigned intermediate_size,
                                              unsigned seq_length,
                                              float attn_prob_dropout_ratio,
                                              float hidden_output_dropout_ratio,
                                              float layer_norm_eps,
                                              bool pre_or_postLayerNorm,
                                              const std::vector<std::array<int, 3>>& gemm_algos,
                                              bool attn_dropout_checkpoint,
                                              bool normalize_invertible,
                                              bool gelu_checkpoint,
                                              bool stochastic_mode)
    : _layer_id(layer_id),
      _batch_size(batch_size),
      _hidden_size(hidden_size),
      _heads(num_heads),
      _intermediate_size(intermediate_size),
      _seq_length(seq_length),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _attn_dropout_checkpoint(attn_dropout_checkpoint),
      _normalize_invertible(normalize_invertible),
      _gelu_checkpoint(gelu_checkpoint),
      _stochastic_mode(stochastic_mode),
      _stream(Context::Instance().GetCurrentStream()),
      _cublasHandle(Context::Instance().GetCublasHandle()),
      _qkv_linear(typename FeedForward<T>::Config(
          batch_size * seq_length, 3 * hidden_size, hidden_size, gemm_algos[0])),
      _attn_out_linear(typename FeedForward<T>::Config(
          batch_size * seq_length, hidden_size, hidden_size, gemm_algos[0])),
      _attn_layer_norm(typename Normalize_Layer<T>::Config(
          batch_size, seq_length, hidden_size, layer_norm_eps, true, !normalize_invertible)),
      _layer_norm(typename Normalize_Layer<T>::Config(
          batch_size, seq_length, hidden_size, layer_norm_eps, true, !normalize_invertible)),
      _ff1(typename FeedForward<T>::Config(
          batch_size * seq_length, _intermediate_size, hidden_size, gemm_algos[1])),
      _ff2(typename FeedForward<T>::Config(
          batch_size * seq_length, hidden_size, _intermediate_size, gemm_algos[2])),
      _softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
      _gelu(typename Gelu<T>::Config(_intermediate_size)),
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
      _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                        _seq_length,
                                                        _seq_length,
                                                        _hidden_size / _heads,
                                                        // aiss debug 0506
                                                        //(T(1.0) / T(sqrt(_hidden_size / _heads))),
                                                        (T(1.0 / (sqrt(_hidden_size / _heads)))),
                                                        T(0.0),
                                                        CUBLAS_OP_T,
                                                        CUBLAS_OP_N,
                                                        gemm_algos[3])),
      _attn_context(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                         _hidden_size / _heads,
                                                         _seq_length,
                                                         _seq_length,
                                                         T(1.0),
                                                         T(0.0),
                                                         CUBLAS_OP_N,
                                                         CUBLAS_OP_N,
                                                         gemm_algos[4]))
{
    assert(_hidden_size % _heads == 0);

    Initialize();
}
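The commented-out scale and its replacement in the _attn_scores configuration are two ways of writing the same softmax scaling for the raw attention scores; with d_k = _hidden_size / _heads it is

\alpha = \frac{1}{\sqrt{d_k}},

applied as the alpha of the strided batch GEMM that forms the attention-score matrix.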
template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer()
{
}

template <typename T>
void BertTransformerLayer<T>::Initialize()
{
#ifndef __HIP_PLATFORM_HCC__
    if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
}
template <typename T>
void BertTransformerLayer<T>::Forward(unsigned bsz,
                                      const T* input_ptr, const T* input_mask_ptr,
                                      const T* attn_qkvw_ptr, const T* attn_qkvb_ptr,
                                      const T* attn_ow_ptr, const T* attn_ob_ptr,
                                      const T* attn_nw_ptr, const T* attn_nb_ptr,
                                      const T* inter_w_ptr, const T* inter_b_ptr,
                                      const T* output_w_ptr, const T* output_b_ptr,
                                      const T* norm_w_ptr, const T* norm_b_ptr,
                                      T* out_ptr, T* inp_norm_ptr,
                                      T* q_tf_ptr, T* k_tf_ptr, T* v_tf_ptr,
                                      T* soft_out_ptr, T* ctx_bufB_ptr, T* attn_o_inp_ptr,
                                      T* add_res_ptr, T* ff1_inp_ptr,
                                      T* gelu_inp_ptr, T* ff2_inp_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) cudaStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1;

    if (_normalize_invertible) {
        add_res_ptr = buf_1 + 3 * small_buf_size;
        buf_2 = add_res_ptr;
    }
    if (_gelu_checkpoint) buf_2 += small_buf_size;
    if (_attn_dropout_checkpoint)
        ctx_bufB_ptr = (_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size)
                                         : (buf_1 + 4 * small_buf_size));

    int bsz_seq = bsz * _seq_length;

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }

    if (_pre_or_postLayerNorm)
        _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
    else
        _qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);

    launch_bias_add_transform_0213<T>(q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);

    int bsz_heads = bsz * _heads;

    // attention scores
    _attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);

    // Softmax + Mask
    _softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);

    // attn prob dropout.
    _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);

    // attention context
    _attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);

    launch_transform4d_0213<T>(attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);

    if (_pre_or_postLayerNorm)
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _cublasHandle);
    else
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _cublasHandle);

    // attn output dropout.
    if (_pre_or_postLayerNorm)
        _attn_output_dropout.ForwardWithBias(bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream);
    else
        _attn_output_dropout.ForwardWithBias(bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    }

    _ff1.Forward(bsz_seq, ff1_inp_ptr, inter_w_ptr, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), _cublasHandle);

    _gelu.ForwardWithBiasAdd(bsz_seq,
                             (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
                             inter_b_ptr,
                             (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                             _stream);

    _ff2.Forward(bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle);

    // layer output dropout.
    if (_pre_or_postLayerNorm)
        _layer_output_dropout.ForwardWithBias(bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
    else
        _layer_output_dropout.ForwardWithBias(bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }
}
template <typename T>
void BertTransformerLayer<T>::Backward(unsigned bsz,
                                       const T* grad_output_ptr, const T* input_ptr,
                                       const T* output_ptr, const T* inp_norm_ptr,
                                       const T* q_tf_ptr, const T* k_tf_ptr, const T* v_tf_ptr,
                                       const T* soft_out_ptr, const T* ctx_bufB_ptr,
                                       const T* attn_o_inp_ptr, const T* add_res_ptr,
                                       const T* ff1_inp_ptr, const T* gelu_inp_ptr, const T* ff2_inp_ptr,
                                       const T* input_mask_ptr, const T* attn_qkvw_ptr,
                                       const T* attn_ow_ptr, const T* attn_nw_ptr, const T* attn_nb_ptr,
                                       const T* inter_w_ptr, const T* inter_b_ptr,
                                       const T* output_w_ptr, const T* norm_w_ptr, const T* norm_b_ptr,
                                       T* grad_input_ptr, T* grad_attn_qkvw_ptr, T* grad_attn_qkvb_ptr,
                                       T* grad_attn_ow_ptr, T* grad_attn_ob_ptr,
                                       T* grad_attn_nw_ptr, T* grad_attn_nb_ptr,
                                       T* grad_inter_w_ptr, T* grad_inter_b_ptr,
                                       T* grad_output_w_ptr, T* grad_output_b_ptr,
                                       T* grad_norm_w_ptr, T* grad_norm_b_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) cudaStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1 + small_buf_size;
    T* buf_3 = buf_2 + small_buf_size;

    T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size)
                                   : buf_3 + small_buf_size);
    T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);

    cudaStream_t streams[2] = {_stream, _stream};

    int bsz_seq = bsz * _seq_length;
    int bsz_heads = bsz * _heads;

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.Backward(bsz_seq, grad_output_ptr, norm_w_ptr, grad_norm_w_ptr, grad_norm_b_ptr, streams, buf_1, inp_norm_ptr);
        else
            _layer_norm.Backward(bsz_seq, grad_output_ptr, norm_w_ptr, norm_b_ptr, grad_norm_w_ptr, grad_norm_b_ptr, streams, buf_1, output_ptr);
    }

    if (_pre_or_postLayerNorm)
        _layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream);
    else
        _layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream);

    const T* layer_dropout_buf = _layer_output_dropout.HasDropout()
                                     ? buf_0
                                     : (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);

    if (_gelu_checkpoint) _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);

    _ff2.Backward(bsz_seq, layer_dropout_buf, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr,
                  grad_output_w_ptr, grad_output_b_ptr, _cublasHandle, _stream, ff2_buf);

    _gelu.Backward(bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);

    _ff1.Backward(bsz_seq, ff2_buf, ff1_inp_ptr, inter_w_ptr, grad_inter_w_ptr, grad_inter_b_ptr,
                  _cublasHandle, _stream, buf_3);

    if (!_pre_or_postLayerNorm)
        launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.BackwardFusedAdd(bsz_seq, buf_3, grad_output_ptr, attn_nw_ptr, grad_attn_nw_ptr, grad_attn_nb_ptr, streams, buf_0, add_res_ptr);
        else
            _attn_layer_norm.BackwardFusedAdd(bsz_seq, buf_3, grad_output_ptr, attn_nw_ptr, attn_nb_ptr, grad_attn_nw_ptr, grad_attn_nb_ptr, streams, buf_0, ff1_inp_ptr);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.Backward(bsz_seq, buf_2, attn_nw_ptr, grad_attn_nw_ptr, grad_attn_nb_ptr, streams, buf_0, add_res_ptr);
        else
            _attn_layer_norm.Backward(bsz_seq, buf_2, attn_nw_ptr, attn_nb_ptr, grad_attn_nw_ptr, grad_attn_nb_ptr, streams, buf_0, ff1_inp_ptr);
    }

    _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);

    T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;

    _attn_out_linear.Backward(bsz_seq, attn_output_dropout_buf, attn_o_inp_ptr, attn_ow_ptr,
                              grad_attn_ow_ptr, grad_attn_ob_ptr, _cublasHandle, _stream, buf_1);

    launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);

    if (_attn_prob_dropout.HasDropout()) {
        if (_attn_dropout_checkpoint)
            _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);

        _attn_context.Backward(bsz_heads, buf_2, v_tf_ptr,
                               (_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr),
                               _cublasHandle, buf_3, ff2_buf);
    } else
        _attn_context.Backward(bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf);

    _attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);

    _softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);

    _attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);

    launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);

    if (_pre_or_postLayerNorm)
        _qkv_linear.Backward(bsz_seq, ff2_buf, inp_norm_ptr, attn_qkvw_ptr, grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr, _cublasHandle, _stream, buf_2);
    else
        _qkv_linear.Backward(bsz_seq, ff2_buf, input_ptr, attn_qkvw_ptr, grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr, _cublasHandle, _stream, buf_2);

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.BackwardFusedAdd(bsz_seq, buf_2, buf_0, norm_w_ptr, grad_norm_w_ptr, grad_norm_b_ptr, streams, grad_input_ptr, input_ptr);
        else
            _layer_norm.BackwardFusedAdd(bsz_seq, buf_2, buf_0, norm_w_ptr, norm_b_ptr, grad_norm_w_ptr, grad_norm_b_ptr, streams, grad_input_ptr, inp_norm_ptr);
    } else
        launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}
template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
    // Dropout will be skipped when not in training mode.
    _attn_prob_dropout.SetTrainingMode(training);
    _attn_output_dropout.SetTrainingMode(training);
    _layer_output_dropout.SetTrainingMode(training);
}

template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                                     uint8_t* attn_output_dropout_mask_ptr,
                                                     uint8_t* layer_output_dropout_mask_ptr,
                                                     T* attn_layer_norm_var,
                                                     T* attn_layer_norm_mean,
                                                     T* layer_norm_var,
                                                     T* layer_norm_mean)
{
    _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
    _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
    _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);

    _attn_layer_norm.SetVar(attn_layer_norm_var);
    _attn_layer_norm.SetMean(attn_layer_norm_mean);
    _layer_norm.SetVar(layer_norm_var);
    _layer_norm.SetMean(layer_norm_mean);
}

template <typename T>
void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
{
    _seq_length = seq_len;

    _softmax.SetSeqLength(_seq_length);
    _attn_prob_dropout.SetDimension(_seq_length);
    _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
    _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
}
template <typename T>
int create_transformer_layer(unsigned layer_id,
                             unsigned batch_size,
                             unsigned hidden_dim,
                             unsigned num_heads,
                             unsigned intermediate_size,
                             float attn_dropout_ratio,
                             float hidden_dropout_ratio,
                             float layer_norm_eps,
                             int seed,
                             bool pre_or_postLayerNorm,
                             bool test_gemm,
                             bool attn_dropout_checkpoint,
                             bool normalize_invertible,
                             bool gelu_checkpoint,
                             bool stochastic_mode)
{
    Context::Instance().SetSeed(seed);
    Context::Instance().TestGemmFP16(test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

    auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
                                                           batch_size,
                                                           hidden_dim,
                                                           num_heads,
                                                           intermediate_size,
                                                           init_seq_length,
                                                           attn_dropout_ratio,
                                                           hidden_dropout_ratio,
                                                           layer_norm_eps,
                                                           pre_or_postLayerNorm,
                                                           Context::Instance().GetGemmAlgos(),
                                                           attn_dropout_checkpoint,
                                                           normalize_invertible,
                                                           gelu_checkpoint,
                                                           stochastic_mode);

    s_transformer_layers[layer_id] = layer;

    std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
    std::cout << "layer #" << layer_id << " is created with data type [" << dtype << "]." << std::endl;

    return 0;
}
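A minimal sketch of registering a layer directly through this helper, with illustrative BERT-base-like argument values (the real arguments arrive from Python via the bindings at the end of the file); every number here is an example only.
// Registers layer 0 in s_transformer_layers; init_seq_length (128) is used until a
// forward call with a different sequence length adjusts it.
create_transformer_layer<float>(/*layer_id=*/0, /*batch_size=*/8, /*hidden_dim=*/768,
                                /*num_heads=*/12, /*intermediate_size=*/3072,
                                /*attn_dropout_ratio=*/0.1f, /*hidden_dropout_ratio=*/0.1f,
                                /*layer_norm_eps=*/1e-12f, /*seed=*/42,
                                /*pre_or_postLayerNorm=*/true, /*test_gemm=*/false,
                                /*attn_dropout_checkpoint=*/false, /*normalize_invertible=*/false,
                                /*gelu_checkpoint=*/false, /*stochastic_mode=*/false);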
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
                                                  const torch::Tensor& input,
                                                  const torch::Tensor& input_mask,
                                                  const torch::Tensor& attn_qkvw,
                                                  const torch::Tensor& attn_qkvb,
                                                  const torch::Tensor& attn_ow,
                                                  const torch::Tensor& attn_ob,
                                                  const torch::Tensor& attn_nw,
                                                  const torch::Tensor& attn_nb,
                                                  const torch::Tensor& inter_w,
                                                  const torch::Tensor& inter_b,
                                                  const torch::Tensor& output_w,
                                                  const torch::Tensor& output_b,
                                                  const torch::Tensor& norm_w,
                                                  const torch::Tensor& norm_b,
                                                  bool training_mode,
                                                  bool prelayernorm,
                                                  bool attn_dropout_checkpoint,
                                                  bool normalize_invertible,
                                                  bool gelu_checkpoint)
{
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = input.size(0);

    const T* input_ptr = (const T*)input.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* output_b_ptr = (const T*)output_b.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    auto output = torch::empty_like(input);
    T* out_ptr = (T*)output.data_ptr();

    auto options = torch::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto uint8_options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .layout(torch::kStrided)
                             .device(torch::kCUDA)
                             .requires_grad(false);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (input.size(1) != seq_len) {
        seq_len = input.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
    auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
    auto attn_o_inp = torch::empty_like(input);
    auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);

    auto attn_prob_dropout_mask = torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
    auto attn_output_dropout_mask = torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
    auto layer_output_dropout_mask = torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);

    auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);

    T* inp_norm_ptr = (T*)inp_norm.data_ptr();
    T* add_res_ptr = (T*)add_res.data_ptr();
    T* q_tf_ptr = (T*)qkv_tf.data_ptr();
    T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)k_tf.data_ptr();
    T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)v_tf.data_ptr();
    T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();

    torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
    torch::Tensor gelu_inp =
        (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
    auto ff1_inp = torch::empty_like(input);
    T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
    T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
    T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();

    torch::Tensor soft_out = torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
    torch::Tensor ctx_bufB =
        (attn_dropout_checkpoint ? soft_out
                                 : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
    T* soft_out_ptr = (T*)soft_out.data_ptr();
    T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();

    layer->SetTrainingMode(training_mode);
    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Forward(bsz, input_ptr, input_mask_ptr, attn_qkvw_ptr, attn_qkvb_ptr, attn_ow_ptr,
                   attn_ob_ptr, attn_nw_ptr, attn_nb_ptr, inter_w_ptr, inter_b_ptr, output_w_ptr,
                   output_b_ptr, norm_w_ptr, norm_b_ptr, out_ptr, inp_norm_ptr, q_tf_ptr, k_tf_ptr,
                   v_tf_ptr, soft_out_ptr, ctx_bufB_ptr, attn_o_inp_ptr, add_res_ptr, ff1_inp_ptr,
                   gelu_inp_ptr, ff2_inp_ptr);

    return {output, inp_norm, qkv_tf, soft_out, ctx_bufB, attn_o_inp, add_res, ff1_inp, gelu_inp,
            ff2_inp, attn_prob_dropout_mask, attn_output_dropout_mask, layer_output_dropout_mask,
            attn_layer_norm_var, attn_layer_norm_mean, layer_norm_var, layer_norm_mean};
}
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
                                                   const torch::Tensor& grad_output,
                                                   const torch::Tensor& output,
                                                   const torch::Tensor& inp_norm,
                                                   const torch::Tensor& qkv_tf,
                                                   const torch::Tensor& soft_out,
                                                   const torch::Tensor& ctx_bufB,
                                                   const torch::Tensor& attn_o_inp,
                                                   const torch::Tensor& add_res,
                                                   const torch::Tensor& ff1_inp,
                                                   const torch::Tensor& gelu_inp,
                                                   const torch::Tensor& ff2_inp,
                                                   const torch::Tensor& attn_prob_dropout_mask,
                                                   const torch::Tensor& attn_output_dropout_mask,
                                                   const torch::Tensor& layer_output_dropout_mask,
                                                   const torch::Tensor& attn_layer_norm_var,
                                                   const torch::Tensor& attn_layer_norm_mean,
                                                   const torch::Tensor& layer_norm_var,
                                                   const torch::Tensor& layer_norm_mean,
                                                   const torch::Tensor& input,
                                                   const torch::Tensor& input_mask,
                                                   const torch::Tensor& attn_qkvw,
                                                   const torch::Tensor& attn_qkvb,
                                                   const torch::Tensor& attn_ow,
                                                   const torch::Tensor& attn_ob,
                                                   const torch::Tensor& attn_nw,
                                                   const torch::Tensor& attn_nb,
                                                   const torch::Tensor& inter_w,
                                                   const torch::Tensor& inter_b,
                                                   const torch::Tensor& output_w,
                                                   const torch::Tensor& output_b,
                                                   const torch::Tensor& norm_w,
                                                   const torch::Tensor& norm_b)
{
    auto g_output = grad_output.contiguous();
    CHECK_INPUT(g_output);
    CHECK_INPUT(output);
    CHECK_INPUT(inp_norm);
    CHECK_INPUT(qkv_tf);
    CHECK_INPUT(add_res);
    CHECK_INPUT(soft_out);
    CHECK_INPUT(ctx_bufB);
    CHECK_INPUT(attn_o_inp);
    CHECK_INPUT(ff1_inp);
    CHECK_INPUT(gelu_inp);
    CHECK_INPUT(ff2_inp);
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = g_output.size(0);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (g_output.size(1) != seq_len) {
        seq_len = g_output.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto options = torch::TensorOptions()
                       .dtype(g_output.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto grad_input = torch::empty_like(input);
    auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
    auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
    auto grad_attn_ow = torch::empty_like(attn_ow);
    auto grad_attn_ob = torch::empty_like(attn_ob);
    auto grad_attn_nw = torch::empty_like(attn_nw);
    auto grad_attn_nb = torch::empty_like(attn_nb);
    auto grad_inter_w = torch::empty_like(inter_w);
    auto grad_inter_b = torch::empty_like(inter_b);
    auto grad_output_w = torch::empty_like(output_w);
    auto grad_output_b = torch::empty_like(output_b);
    auto grad_norm_w = torch::empty_like(norm_w);
    auto grad_norm_b = torch::empty_like(norm_b);

    // inputs.
    const T* grad_output_ptr = (const T*)g_output.data_ptr();
    const T* input_ptr = (const T*)input.data_ptr();
    const T* output_ptr = (const T*)output.data_ptr();
    const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
    const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
    const T* add_res_ptr = (const T*)add_res.data_ptr();
    const T* k_tf_ptr = q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)k_tf.data_ptr();
    const T* v_tf_ptr = k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)v_tf.data_ptr();
    const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
    const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
    const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
    const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
    const T* soft_out_ptr = (const T*)soft_out.data_ptr();
    const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    // outputs.
    T* grad_input_ptr = (T*)grad_input.data_ptr();
    T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
    T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
    T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
    T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
    T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
    T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
    T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
    T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
    T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
    T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
    T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
    T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();

    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Backward(bsz, grad_output_ptr, input_ptr, output_ptr, inp_norm_ptr, q_tf_ptr, k_tf_ptr,
                    v_tf_ptr, soft_out_ptr, ctx_bufB_ptr, attn_o_inp_ptr, add_res_ptr, ff1_inp_ptr,
                    gelu_inp_ptr, ff2_inp_ptr, input_mask_ptr, attn_qkvw_ptr, attn_ow_ptr,
                    attn_nw_ptr, attn_nb_ptr, inter_w_ptr, inter_b_ptr, output_w_ptr, norm_w_ptr,
                    norm_b_ptr, grad_input_ptr, grad_attn_qkvw_ptr, grad_attn_qkvb_ptr,
                    grad_attn_ow_ptr, grad_attn_ob_ptr, grad_attn_nw_ptr, grad_attn_nb_ptr,
                    grad_inter_w_ptr, grad_inter_b_ptr, grad_output_w_ptr, grad_output_b_ptr,
                    grad_norm_w_ptr, grad_norm_b_ptr);

    return {grad_input, grad_attn_qkvw, grad_attn_qkvb, grad_attn_ow, grad_attn_ob, grad_attn_nw,
            grad_attn_nb, grad_inter_w, grad_inter_b, grad_output_w, grad_output_b, grad_norm_w,
            grad_norm_b};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward_fp32", &ds_transformer_forward<float>, "DeepSpeed Transformer forward with fp32 (CUDA)");
    m.def("forward_fp16", &ds_transformer_forward<__half>, "DeepSpeed Transformer forward with fp16 (CUDA)");
    m.def("backward_fp32", &ds_transformer_backward<float>, "DeepSpeed Transformer backward with fp32 (CUDA)");
    m.def("backward_fp16", &ds_transformer_backward<__half>, "DeepSpeed Transformer backward with fp16 (CUDA)");
    m.def("create_transformer_layer_fp32", &create_transformer_layer<float>, "Create DeepSpeed Transformer Layer with fp32 (CUDA)");
    m.def("create_transformer_layer_fp16", &create_transformer_layer<__half>, "Create DeepSpeed Transformer Layer with fp16 (CUDA)");
}
csrc/transformer_bak/ds_transformer_hip.cpp
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include <torch/extension.h>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer_hip.h"
#include "context_hip.h"
#include "cublas_wrappers_hip.h"
#include "custom_hip_layers.h"
#include "ds_transformer_hip.h"
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

const int init_seq_length = 128;

// C++ interface

template <typename T>
unsigned get_workspace_size(unsigned maxBatchSize,
                            unsigned seq_len,
                            unsigned hidden_size,
                            unsigned intermediate_size,
                            unsigned heads,
                            bool training,
                            bool gelu_checkpoint)
{
    unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
    if (training) {
        workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
        workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
                                     2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
        if (gelu_checkpoint)
            workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
    }
    return workSpacesize;  // * sizeof(T);
}
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
                                              unsigned batch_size,
                                              unsigned hidden_size,
                                              unsigned num_heads,
                                              unsigned intermediate_size,
                                              unsigned seq_length,
                                              float attn_prob_dropout_ratio,
                                              float hidden_output_dropout_ratio,
                                              float layer_norm_eps,
                                              bool pre_or_postLayerNorm,
                                              const std::vector<std::array<int, 3>>& gemm_algos,
                                              bool attn_dropout_checkpoint,
                                              bool normalize_invertible,
                                              bool gelu_checkpoint,
                                              bool stochastic_mode)
    : _layer_id(layer_id),
      _batch_size(batch_size),
      _hidden_size(hidden_size),
      _heads(num_heads),
      _intermediate_size(intermediate_size),
      _seq_length(seq_length),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _attn_dropout_checkpoint(attn_dropout_checkpoint),
      _normalize_invertible(normalize_invertible),
      _gelu_checkpoint(gelu_checkpoint),
      _stochastic_mode(stochastic_mode),
      _stream(Context::Instance().GetCurrentStream()),
      _cublasHandle(Context::Instance().GetCublasHandle()),
      _qkv_linear(typename FeedForward<T>::Config(
          batch_size * seq_length, 3 * hidden_size, hidden_size, gemm_algos[0])),
      _attn_out_linear(typename FeedForward<T>::Config(
          batch_size * seq_length, hidden_size, hidden_size, gemm_algos[0])),
      _attn_layer_norm(typename Normalize_Layer<T>::Config(
          batch_size, seq_length, hidden_size, layer_norm_eps, true, !normalize_invertible)),
      _layer_norm(typename Normalize_Layer<T>::Config(
          batch_size, seq_length, hidden_size, layer_norm_eps, true, !normalize_invertible)),
      _ff1(typename FeedForward<T>::Config(
          batch_size * seq_length, _intermediate_size, hidden_size, gemm_algos[1])),
      _ff2(typename FeedForward<T>::Config(
          batch_size * seq_length, hidden_size, _intermediate_size, gemm_algos[2])),
      _softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
      _gelu(typename Gelu<T>::Config(_intermediate_size)),
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
      _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                        _seq_length,
                                                        _seq_length,
                                                        _hidden_size / _heads,
                                                        // aiss debug 0506
                                                        //(T(1.0) / T(sqrt(_hidden_size / _heads))),
                                                        (T(1.0 / (sqrt(_hidden_size / _heads)))),
                                                        T(0.0),
                                                        rocblas_operation_transpose,
                                                        rocblas_operation_none,
                                                        gemm_algos[3])),
      _attn_context(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                         _hidden_size / _heads,
                                                         _seq_length,
                                                         _seq_length,
                                                         T(1.0),
                                                         T(0.0),
                                                         rocblas_operation_none,
                                                         rocblas_operation_none,
                                                         gemm_algos[4]))
{
    assert(_hidden_size % _heads == 0);

    Initialize();
}
template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer()
{
}

template <typename T>
void BertTransformerLayer<T>::Initialize()
{
#ifndef __HIP_PLATFORM_HCC__
    if (std::is_same<T, __half>::value) rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
}
template <typename T>
void BertTransformerLayer<T>::Forward(unsigned bsz,
                                      const T* input_ptr, const T* input_mask_ptr,
                                      const T* attn_qkvw_ptr, const T* attn_qkvb_ptr,
                                      const T* attn_ow_ptr, const T* attn_ob_ptr,
                                      const T* attn_nw_ptr, const T* attn_nb_ptr,
                                      const T* inter_w_ptr, const T* inter_b_ptr,
                                      const T* output_w_ptr, const T* output_b_ptr,
                                      const T* norm_w_ptr, const T* norm_b_ptr,
                                      T* out_ptr, T* inp_norm_ptr,
                                      T* q_tf_ptr, T* k_tf_ptr, T* v_tf_ptr,
                                      T* soft_out_ptr, T* ctx_bufB_ptr, T* attn_o_inp_ptr,
                                      T* add_res_ptr, T* ff1_inp_ptr,
                                      T* gelu_inp_ptr, T* ff2_inp_ptr)
{
    rocblas_set_stream(_cublasHandle, _stream);

    if (!_stochastic_mode) hipStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1;

    if (_normalize_invertible) {
        add_res_ptr = buf_1 + 3 * small_buf_size;
        buf_2 = add_res_ptr;
    }
    if (_gelu_checkpoint) buf_2 += small_buf_size;
    if (_attn_dropout_checkpoint)
        ctx_bufB_ptr = (_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size)
                                         : (buf_1 + 4 * small_buf_size));

    int bsz_seq = bsz * _seq_length;

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }

    if (_pre_or_postLayerNorm)
        _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
    else
        _qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);

    launch_bias_add_transform_0213<T>(q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);

    int bsz_heads = bsz * _heads;

    // attention scores
    _attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);

    // Softmax + Mask
    _softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);

    // attn prob dropout.
    _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);

    // attention context
    _attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);

    launch_transform4d_0213<T>(attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);

    if (_pre_or_postLayerNorm)
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _cublasHandle);
    else
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _cublasHandle);

    // attn output dropout.
    if (_pre_or_postLayerNorm)
        _attn_output_dropout.ForwardWithBias(bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream);
    else
        _attn_output_dropout.ForwardWithBias(bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    }

    _ff1.Forward(bsz_seq, ff1_inp_ptr, inter_w_ptr, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), _cublasHandle);

    _gelu.ForwardWithBiasAdd(bsz_seq,
                             (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
                             inter_b_ptr,
                             (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                             _stream);

    _ff2.Forward(bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle);

    // layer output dropout.
    if (_pre_or_postLayerNorm)
        _layer_output_dropout.ForwardWithBias(bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
    else
        _layer_output_dropout.ForwardWithBias(bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
Backward
(
unsigned
bsz
,
const
T
*
grad_output_ptr
,
const
T
*
input_ptr
,
const
T
*
output_ptr
,
const
T
*
inp_norm_ptr
,
const
T
*
q_tf_ptr
,
const
T
*
k_tf_ptr
,
const
T
*
v_tf_ptr
,
const
T
*
soft_out_ptr
,
const
T
*
ctx_bufB_ptr
,
const
T
*
attn_o_inp_ptr
,
const
T
*
add_res_ptr
,
const
T
*
ff1_inp_ptr
,
const
T
*
gelu_inp_ptr
,
const
T
*
ff2_inp_ptr
,
const
T
*
input_mask_ptr
,
const
T
*
attn_qkvw_ptr
,
const
T
*
attn_ow_ptr
,
const
T
*
attn_nw_ptr
,
const
T
*
attn_nb_ptr
,
const
T
*
inter_w_ptr
,
const
T
*
inter_b_ptr
,
const
T
*
output_w_ptr
,
const
T
*
norm_w_ptr
,
const
T
*
norm_b_ptr
,
T
*
grad_input_ptr
,
T
*
grad_attn_qkvw_ptr
,
T
*
grad_attn_qkvb_ptr
,
T
*
grad_attn_ow_ptr
,
T
*
grad_attn_ob_ptr
,
T
*
grad_attn_nw_ptr
,
T
*
grad_attn_nb_ptr
,
T
*
grad_inter_w_ptr
,
T
*
grad_inter_b_ptr
,
T
*
grad_output_w_ptr
,
T
*
grad_output_b_ptr
,
T
*
grad_norm_w_ptr
,
T
*
grad_norm_b_ptr
)
{
rocblas_set_stream
(
_cublasHandle
,
_stream
);
if
(
!
_stochastic_mode
)
hipStreamSynchronize
(
_stream
);
T
*
workspace
=
static_cast
<
T
*>
(
Context
::
Instance
().
GetWorkSpace
());
size_t
small_buf_size
=
bsz
*
_seq_length
*
_hidden_size
;
T
*
buf_0
=
workspace
;
T
*
buf_1
=
buf_0
+
small_buf_size
;
T
*
buf_2
=
buf_1
+
small_buf_size
;
T
*
buf_3
=
buf_2
+
small_buf_size
;
T
*
ff2_buf
=
(
_gelu_checkpoint
?
buf_3
+
(
bsz
*
_seq_length
*
_intermediate_size
)
:
buf_3
+
small_buf_size
);
T
*
ctx_bufB_ptr_recomp
=
ff2_buf
+
(
_seq_length
*
_seq_length
*
bsz
*
_heads
);
hipStream_t
streams
[
2
]
=
{
_stream
,
_stream
};
int
bsz_seq
=
bsz
*
_seq_length
;
int
bsz_heads
=
bsz
*
_heads
;
if
(
!
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
Backward
(
bsz_seq
,
grad_output_ptr
,
norm_w_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
buf_1
,
inp_norm_ptr
);
else
_layer_norm
.
Backward
(
bsz_seq
,
grad_output_ptr
,
norm_w_ptr
,
norm_b_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
buf_1
,
output_ptr
);
}
if
(
_pre_or_postLayerNorm
)
_layer_output_dropout
.
Backward
(
bsz_seq
,
buf_0
,
grad_output_ptr
,
_stream
);
else
_layer_output_dropout
.
Backward
(
bsz_seq
,
buf_0
,
buf_1
,
_stream
);
const
T
*
layer_dropout_buf
=
_layer_output_dropout
.
HasDropout
()
?
buf_0
:
(
_pre_or_postLayerNorm
?
grad_output_ptr
:
buf_1
);
if
(
_gelu_checkpoint
)
_gelu
.
ForwardWithBiasAdd
(
bsz_seq
,
ff2_inp_ptr
,
inter_b_ptr
,
buf_2
,
_stream
);
_ff2
.
Backward
(
bsz_seq
,
layer_dropout_buf
,
(
_gelu_checkpoint
?
buf_2
:
ff2_inp_ptr
),
output_w_ptr
,
grad_output_w_ptr
,
grad_output_b_ptr
,
_cublasHandle
,
_stream
,
ff2_buf
);
_gelu
.
Backward
(
bsz_seq
,
ff2_buf
,
(
_gelu_checkpoint
?
ff2_inp_ptr
:
gelu_inp_ptr
),
inter_b_ptr
,
_stream
);
_ff1
.
Backward
(
bsz_seq
,
ff2_buf
,
ff1_inp_ptr
,
inter_w_ptr
,
grad_inter_w_ptr
,
grad_inter_b_ptr
,
_cublasHandle
,
_stream
,
buf_3
);
if
(
!
_pre_or_postLayerNorm
)
launch_fused_add2
<
T
>
(
buf_2
,
buf_3
,
buf_1
,
bsz
,
_seq_length
,
_hidden_size
,
_stream
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_3
,
grad_output_ptr
,
attn_nw_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
add_res_ptr
);
else
_attn_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_3
,
grad_output_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
ff1_inp_ptr
);
}
else
{
if
(
_attn_layer_norm
.
UseMean
())
_attn_layer_norm
.
Backward
(
bsz_seq
,
buf_2
,
attn_nw_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
add_res_ptr
);
else
_attn_layer_norm
.
Backward
(
bsz_seq
,
buf_2
,
attn_nw_ptr
,
attn_nb_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
streams
,
buf_0
,
ff1_inp_ptr
);
}
_attn_output_dropout
.
Backward
(
bsz_seq
,
buf_2
,
buf_0
,
_stream
);
T
*
attn_output_dropout_buf
=
_attn_output_dropout
.
HasDropout
()
?
buf_2
:
buf_0
;
_attn_out_linear
.
Backward
(
bsz_seq
,
attn_output_dropout_buf
,
attn_o_inp_ptr
,
attn_ow_ptr
,
grad_attn_ow_ptr
,
grad_attn_ob_ptr
,
_cublasHandle
,
_stream
,
buf_1
);
launch_transform_0213
<
T
>
(
buf_2
,
buf_1
,
bsz
,
_seq_length
,
_hidden_size
,
_heads
,
_stream
);
if
(
_attn_prob_dropout
.
HasDropout
())
{
if
(
_attn_dropout_checkpoint
)
_attn_prob_dropout
.
Forward
(
bsz_heads
*
_seq_length
,
ctx_bufB_ptr_recomp
,
soft_out_ptr
,
_stream
,
true
);
_attn_context
.
Backward
(
bsz_heads
,
buf_2
,
v_tf_ptr
,
(
_attn_dropout_checkpoint
?
ctx_bufB_ptr_recomp
:
ctx_bufB_ptr
),
_cublasHandle
,
buf_3
,
ff2_buf
);
}
else
_attn_context
.
Backward
(
bsz_heads
,
buf_2
,
v_tf_ptr
,
soft_out_ptr
,
_cublasHandle
,
buf_3
,
ff2_buf
);
_attn_prob_dropout
.
Backward
(
bsz_heads
*
_seq_length
,
ff2_buf
,
_stream
);
_softmax
.
Backward
(
bsz
,
ff2_buf
,
soft_out_ptr
,
_stream
);
_attn_scores
.
Backward
(
bsz_heads
,
ff2_buf
,
k_tf_ptr
,
q_tf_ptr
,
_cublasHandle
,
buf_2
,
buf_1
);
launch_transform4d_0213
(
ff2_buf
,
buf_1
,
bsz
,
_heads
,
_seq_length
,
_hidden_size
,
_stream
,
3
);
if
(
_pre_or_postLayerNorm
)
_qkv_linear
.
Backward
(
bsz_seq
,
ff2_buf
,
inp_norm_ptr
,
attn_qkvw_ptr
,
grad_attn_qkvw_ptr
,
grad_attn_qkvb_ptr
,
_cublasHandle
,
_stream
,
buf_2
);
else
_qkv_linear
.
Backward
(
bsz_seq
,
ff2_buf
,
input_ptr
,
attn_qkvw_ptr
,
grad_attn_qkvw_ptr
,
grad_attn_qkvb_ptr
,
_cublasHandle
,
_stream
,
buf_2
);
if
(
_pre_or_postLayerNorm
)
{
if
(
_layer_norm
.
UseMean
())
_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_2
,
buf_0
,
norm_w_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
grad_input_ptr
,
input_ptr
);
else
_layer_norm
.
BackwardFusedAdd
(
bsz_seq
,
buf_2
,
buf_0
,
norm_w_ptr
,
norm_b_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
,
streams
,
grad_input_ptr
,
inp_norm_ptr
);
}
else
launch_fused_add2
<
T
>
(
grad_input_ptr
,
buf_2
,
buf_0
,
bsz
,
_seq_length
,
_hidden_size
,
_stream
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetTrainingMode
(
bool
training
)
{
// Dropout will be skipped when not in training model.
_attn_prob_dropout
.
SetTrainingMode
(
training
);
_attn_output_dropout
.
SetTrainingMode
(
training
);
_layer_output_dropout
.
SetTrainingMode
(
training
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetIntermediateBuffers
(
uint8_t
*
attn_prob_dropout_mask_ptr
,
uint8_t
*
attn_output_dropout_mask_ptr
,
uint8_t
*
layer_output_dropout_mask_ptr
,
T
*
attn_layer_norm_var
,
T
*
attn_layer_norm_mean
,
T
*
layer_norm_var
,
T
*
layer_norm_mean
)
{
_attn_prob_dropout
.
SetMask
(
attn_prob_dropout_mask_ptr
);
_attn_output_dropout
.
SetMask
(
attn_output_dropout_mask_ptr
);
_layer_output_dropout
.
SetMask
(
layer_output_dropout_mask_ptr
);
_attn_layer_norm
.
SetVar
(
attn_layer_norm_var
);
_attn_layer_norm
.
SetMean
(
attn_layer_norm_mean
);
_layer_norm
.
SetVar
(
layer_norm_var
);
_layer_norm
.
SetMean
(
layer_norm_mean
);
}
template
<
typename
T
>
void
BertTransformerLayer
<
T
>::
SetSeqLength
(
unsigned
seq_len
)
{
_seq_length
=
seq_len
;
_softmax
.
SetSeqLength
(
_seq_length
);
_attn_prob_dropout
.
SetDimension
(
_seq_length
);
_attn_scores
.
SetConfig
(
_seq_length
,
_seq_length
,
_hidden_size
/
_heads
);
_attn_context
.
SetConfig
(
_hidden_size
/
_heads
,
_seq_length
,
_seq_length
);
}
template
<
typename
T
>
int
create_transformer_layer
(
unsigned
layer_id
,
unsigned
batch_size
,
unsigned
hidden_dim
,
unsigned
num_heads
,
unsigned
intermediate_size
,
float
attn_dropout_ratio
,
float
hidden_dropout_ratio
,
float
layer_norm_eps
,
int
seed
,
bool
pre_or_postLayerNorm
,
bool
test_gemm
,
bool
attn_dropout_checkpoint
,
bool
normalize_invertible
,
bool
gelu_checkpoint
,
bool
stochastic_mode
)
{
Context
::
Instance
().
SetSeed
(
seed
);
Context
::
Instance
().
TestGemmFP16
(
test_gemm
,
batch_size
,
init_seq_length
,
num_heads
,
hidden_dim
/
num_heads
);
auto
layer
=
std
::
make_shared
<
BertTransformerLayer
<
T
>>
(
layer_id
,
batch_size
,
hidden_dim
,
num_heads
,
intermediate_size
,
init_seq_length
,
attn_dropout_ratio
,
hidden_dropout_ratio
,
layer_norm_eps
,
pre_or_postLayerNorm
,
Context
::
Instance
().
GetGemmAlgos
(),
attn_dropout_checkpoint
,
normalize_invertible
,
gelu_checkpoint
,
stochastic_mode
);
s_transformer_layers
[
layer_id
]
=
layer
;
std
::
string
dtype
=
(
std
::
is_same
<
T
,
__half
>::
value
)
?
"half"
:
"float"
;
std
::
cout
<<
"layer #"
<<
layer_id
<<
" is created with date type ["
<<
dtype
<<
"]."
<<
std
::
endl
;
return
0
;
}
template
<
typename
T
>
std
::
vector
<
torch
::
Tensor
>
ds_transformer_forward
(
unsigned
layer_id
,
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
input_mask
,
const
torch
::
Tensor
&
attn_qkvw
,
const
torch
::
Tensor
&
attn_qkvb
,
const
torch
::
Tensor
&
attn_ow
,
const
torch
::
Tensor
&
attn_ob
,
const
torch
::
Tensor
&
attn_nw
,
const
torch
::
Tensor
&
attn_nb
,
const
torch
::
Tensor
&
inter_w
,
const
torch
::
Tensor
&
inter_b
,
const
torch
::
Tensor
&
output_w
,
const
torch
::
Tensor
&
output_b
,
const
torch
::
Tensor
&
norm_w
,
const
torch
::
Tensor
&
norm_b
,
bool
training_mode
,
bool
prelayernorm
,
bool
attn_dropout_checkpoint
,
bool
normalize_invertible
,
bool
gelu_checkpoint
)
{
CHECK_INPUT
(
input
);
CHECK_INPUT
(
input_mask
);
CHECK_INPUT
(
attn_qkvw
);
CHECK_INPUT
(
attn_qkvb
);
CHECK_INPUT
(
attn_ow
);
CHECK_INPUT
(
attn_ob
);
CHECK_INPUT
(
attn_nw
);
CHECK_INPUT
(
attn_nb
);
CHECK_INPUT
(
inter_w
);
CHECK_INPUT
(
inter_b
);
CHECK_INPUT
(
output_w
);
CHECK_INPUT
(
output_b
);
CHECK_INPUT
(
norm_w
);
CHECK_INPUT
(
norm_b
);
unsigned
bsz
=
input
.
size
(
0
);
const
T
*
input_ptr
=
(
const
T
*
)
input
.
data_ptr
();
const
T
*
input_mask_ptr
=
(
const
T
*
)
input_mask
.
data_ptr
();
const
T
*
attn_qkvw_ptr
=
(
const
T
*
)
attn_qkvw
.
data_ptr
();
const
T
*
attn_qkvb_ptr
=
(
const
T
*
)
attn_qkvb
.
data_ptr
();
const
T
*
attn_ow_ptr
=
(
const
T
*
)
attn_ow
.
data_ptr
();
const
T
*
attn_ob_ptr
=
(
const
T
*
)
attn_ob
.
data_ptr
();
const
T
*
attn_nw_ptr
=
(
const
T
*
)
attn_nw
.
data_ptr
();
const
T
*
attn_nb_ptr
=
(
const
T
*
)
attn_nb
.
data_ptr
();
const
T
*
inter_w_ptr
=
(
const
T
*
)
inter_w
.
data_ptr
();
const
T
*
inter_b_ptr
=
(
const
T
*
)
inter_b
.
data_ptr
();
const
T
*
output_w_ptr
=
(
const
T
*
)
output_w
.
data_ptr
();
const
T
*
output_b_ptr
=
(
const
T
*
)
output_b
.
data_ptr
();
const
T
*
norm_w_ptr
=
(
const
T
*
)
norm_w
.
data_ptr
();
const
T
*
norm_b_ptr
=
(
const
T
*
)
norm_b
.
data_ptr
();
auto
output
=
torch
::
empty_like
(
input
);
T
*
out_ptr
=
(
T
*
)
output
.
data_ptr
();
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
input
.
options
().
dtype
())
.
layout
(
torch
::
kStrided
)
.
device
(
torch
::
kCUDA
)
.
requires_grad
(
true
);
auto
uint8_options
=
torch
::
TensorOptions
()
.
dtype
(
torch
::
kInt8
)
.
layout
(
torch
::
kStrided
)
.
device
(
torch
::
kCUDA
)
.
requires_grad
(
false
);
std
::
shared_ptr
<
BertTransformerLayer
<
T
>>
layer
=
std
::
static_pointer_cast
<
BertTransformerLayer
<
T
>>
(
s_transformer_layers
[
layer_id
]);
unsigned
seq_len
=
layer
->
GetSeqLength
();
if
(
input
.
size
(
1
)
!=
seq_len
)
{
seq_len
=
input
.
size
(
1
);
layer
->
SetSeqLength
(
seq_len
);
}
auto
workspace
=
torch
::
empty
({
get_workspace_size
<
T
>
(
bsz
,
seq_len
,
layer
->
GetHiddenSize
(),
layer
->
GetIntermediateSize
(),
layer
->
GetNumHeads
(),
layer
->
IsTrainingMode
(),
layer
->
GeluCheckpoint
())},
options
);
Context
::
Instance
().
SetWorkSpace
((
T
*
)
workspace
.
data_ptr
());
auto
inp_norm
=
((
prelayernorm
||
!
normalize_invertible
)
?
torch
::
empty_like
(
input
)
:
output
);
auto
add_res
=
(
normalize_invertible
?
inp_norm
:
torch
::
empty_like
(
input
));
auto
attn_o_inp
=
torch
::
empty_like
(
input
);
auto
qkv_tf
=
torch
::
empty
({(
bsz
*
seq_len
),
output_w
.
size
(
0
)
*
3
},
options
);
auto
attn_prob_dropout_mask
=
torch
::
empty
({(
bsz
*
layer
->
GetNumHeads
()
*
seq_len
),
seq_len
},
uint8_options
);
auto
attn_output_dropout_mask
=
torch
::
empty
({(
bsz
*
seq_len
),
layer
->
GetHiddenSize
()},
uint8_options
);
auto
layer_output_dropout_mask
=
torch
::
empty
({(
bsz
*
seq_len
),
layer
->
GetHiddenSize
()},
uint8_options
);
auto
attn_layer_norm_var
=
torch
::
empty
({(
bsz
*
seq_len
)},
options
);
auto
attn_layer_norm_mean
=
torch
::
empty
({(
bsz
*
seq_len
)},
options
);
auto
layer_norm_var
=
torch
::
empty
({(
bsz
*
seq_len
)},
options
);
auto
layer_norm_mean
=
torch
::
empty
({(
bsz
*
seq_len
)},
options
);
T
*
inp_norm_ptr
=
(
T
*
)
inp_norm
.
data_ptr
();
T
*
add_res_ptr
=
(
T
*
)
add_res
.
data_ptr
();
T
*
q_tf_ptr
=
(
T
*
)
qkv_tf
.
data_ptr
();
T
*
k_tf_ptr
=
q_tf_ptr
+
(
bsz
*
seq_len
*
output_w
.
size
(
0
));
//(T*)k_tf.data_ptr();
T
*
v_tf_ptr
=
k_tf_ptr
+
(
bsz
*
seq_len
*
output_w
.
size
(
0
));
//(T*)v_tf.data_ptr();
T
*
attn_o_inp_ptr
=
(
T
*
)
attn_o_inp
.
data_ptr
();
torch
::
Tensor
ff2_inp
=
torch
::
empty
({(
bsz
*
seq_len
),
output_w
.
size
(
1
)},
options
);
torch
::
Tensor
gelu_inp
=
(
gelu_checkpoint
?
ff2_inp
:
torch
::
empty
({(
bsz
*
seq_len
),
output_w
.
size
(
1
)},
options
));
auto
ff1_inp
=
torch
::
empty_like
(
input
);
T
*
ff2_inp_ptr
=
(
T
*
)
ff2_inp
.
data_ptr
();
T
*
gelu_inp_ptr
=
(
T
*
)
gelu_inp
.
data_ptr
();
T
*
ff1_inp_ptr
=
(
T
*
)
ff1_inp
.
data_ptr
();
torch
::
Tensor
soft_out
=
torch
::
empty
({(
bsz
*
layer
->
GetNumHeads
()
*
seq_len
),
seq_len
},
options
);
torch
::
Tensor
ctx_bufB
=
(
attn_dropout_checkpoint
?
soft_out
:
torch
::
empty
({(
bsz
*
layer
->
GetNumHeads
()
*
seq_len
),
seq_len
},
options
));
T
*
soft_out_ptr
=
(
T
*
)
soft_out
.
data_ptr
();
T
*
ctx_bufB_ptr
=
(
T
*
)
ctx_bufB
.
data_ptr
();
layer
->
SetTrainingMode
(
training_mode
);
layer
->
SetIntermediateBuffers
((
uint8_t
*
)
attn_prob_dropout_mask
.
data_ptr
(),
(
uint8_t
*
)
attn_output_dropout_mask
.
data_ptr
(),
(
uint8_t
*
)
layer_output_dropout_mask
.
data_ptr
(),
(
T
*
)
attn_layer_norm_var
.
data_ptr
(),
(
T
*
)
attn_layer_norm_mean
.
data_ptr
(),
(
T
*
)
layer_norm_var
.
data_ptr
(),
(
T
*
)
layer_norm_mean
.
data_ptr
());
layer
->
Forward
(
bsz
,
input_ptr
,
input_mask_ptr
,
attn_qkvw_ptr
,
attn_qkvb_ptr
,
attn_ow_ptr
,
attn_ob_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
inter_w_ptr
,
inter_b_ptr
,
output_w_ptr
,
output_b_ptr
,
norm_w_ptr
,
norm_b_ptr
,
out_ptr
,
inp_norm_ptr
,
q_tf_ptr
,
k_tf_ptr
,
v_tf_ptr
,
soft_out_ptr
,
ctx_bufB_ptr
,
attn_o_inp_ptr
,
add_res_ptr
,
ff1_inp_ptr
,
gelu_inp_ptr
,
ff2_inp_ptr
);
return
{
output
,
inp_norm
,
qkv_tf
,
soft_out
,
ctx_bufB
,
attn_o_inp
,
add_res
,
ff1_inp
,
gelu_inp
,
ff2_inp
,
attn_prob_dropout_mask
,
attn_output_dropout_mask
,
layer_output_dropout_mask
,
attn_layer_norm_var
,
attn_layer_norm_mean
,
layer_norm_var
,
layer_norm_mean
};
}
template
<
typename
T
>
std
::
vector
<
torch
::
Tensor
>
ds_transformer_backward
(
unsigned
layer_id
,
const
torch
::
Tensor
&
grad_output
,
const
torch
::
Tensor
&
output
,
const
torch
::
Tensor
&
inp_norm
,
const
torch
::
Tensor
&
qkv_tf
,
const
torch
::
Tensor
&
soft_out
,
const
torch
::
Tensor
&
ctx_bufB
,
const
torch
::
Tensor
&
attn_o_inp
,
const
torch
::
Tensor
&
add_res
,
const
torch
::
Tensor
&
ff1_inp
,
const
torch
::
Tensor
&
gelu_inp
,
const
torch
::
Tensor
&
ff2_inp
,
const
torch
::
Tensor
&
attn_prob_dropout_mask
,
const
torch
::
Tensor
&
attn_output_dropout_mask
,
const
torch
::
Tensor
&
layer_output_dropout_mask
,
const
torch
::
Tensor
&
attn_layer_norm_var
,
const
torch
::
Tensor
&
attn_layer_norm_mean
,
const
torch
::
Tensor
&
layer_norm_var
,
const
torch
::
Tensor
&
layer_norm_mean
,
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
input_mask
,
const
torch
::
Tensor
&
attn_qkvw
,
const
torch
::
Tensor
&
attn_qkvb
,
const
torch
::
Tensor
&
attn_ow
,
const
torch
::
Tensor
&
attn_ob
,
const
torch
::
Tensor
&
attn_nw
,
const
torch
::
Tensor
&
attn_nb
,
const
torch
::
Tensor
&
inter_w
,
const
torch
::
Tensor
&
inter_b
,
const
torch
::
Tensor
&
output_w
,
const
torch
::
Tensor
&
output_b
,
const
torch
::
Tensor
&
norm_w
,
const
torch
::
Tensor
&
norm_b
)
{
auto
g_output
=
grad_output
.
contiguous
();
CHECK_INPUT
(
g_output
);
CHECK_INPUT
(
output
);
CHECK_INPUT
(
inp_norm
);
CHECK_INPUT
(
qkv_tf
);
CHECK_INPUT
(
add_res
);
CHECK_INPUT
(
soft_out
);
CHECK_INPUT
(
ctx_bufB
);
CHECK_INPUT
(
attn_o_inp
);
CHECK_INPUT
(
ff1_inp
);
CHECK_INPUT
(
gelu_inp
);
CHECK_INPUT
(
ff2_inp
);
CHECK_INPUT
(
input
);
CHECK_INPUT
(
input_mask
);
CHECK_INPUT
(
attn_qkvw
);
CHECK_INPUT
(
attn_qkvb
);
CHECK_INPUT
(
attn_ow
);
CHECK_INPUT
(
attn_ob
);
CHECK_INPUT
(
attn_nw
);
CHECK_INPUT
(
attn_nb
);
CHECK_INPUT
(
inter_w
);
CHECK_INPUT
(
inter_b
);
CHECK_INPUT
(
output_w
);
CHECK_INPUT
(
output_b
);
CHECK_INPUT
(
norm_w
);
CHECK_INPUT
(
norm_b
);
unsigned
bsz
=
g_output
.
size
(
0
);
std
::
shared_ptr
<
BertTransformerLayer
<
T
>>
layer
=
std
::
static_pointer_cast
<
BertTransformerLayer
<
T
>>
(
s_transformer_layers
[
layer_id
]);
unsigned
seq_len
=
layer
->
GetSeqLength
();
if
(
g_output
.
size
(
1
)
!=
seq_len
)
{
seq_len
=
g_output
.
size
(
1
);
layer
->
SetSeqLength
(
seq_len
);
}
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
g_output
.
options
().
dtype
())
.
layout
(
torch
::
kStrided
)
.
device
(
torch
::
kCUDA
)
.
requires_grad
(
true
);
auto
workspace
=
torch
::
empty
({
get_workspace_size
<
T
>
(
bsz
,
seq_len
,
layer
->
GetHiddenSize
(),
layer
->
GetIntermediateSize
(),
layer
->
GetNumHeads
(),
layer
->
IsTrainingMode
(),
layer
->
GeluCheckpoint
())},
options
);
Context
::
Instance
().
SetWorkSpace
((
T
*
)
workspace
.
data_ptr
());
auto
grad_input
=
torch
::
empty_like
(
input
);
auto
grad_attn_qkvw
=
torch
::
empty_like
(
attn_qkvw
);
auto
grad_attn_qkvb
=
torch
::
empty_like
(
attn_qkvb
);
auto
grad_attn_ow
=
torch
::
empty_like
(
attn_ow
);
auto
grad_attn_ob
=
torch
::
empty_like
(
attn_ob
);
auto
grad_attn_nw
=
torch
::
empty_like
(
attn_nw
);
auto
grad_attn_nb
=
torch
::
empty_like
(
attn_nb
);
auto
grad_inter_w
=
torch
::
empty_like
(
inter_w
);
auto
grad_inter_b
=
torch
::
empty_like
(
inter_b
);
auto
grad_output_w
=
torch
::
empty_like
(
output_w
);
auto
grad_output_b
=
torch
::
empty_like
(
output_b
);
auto
grad_norm_w
=
torch
::
empty_like
(
norm_w
);
auto
grad_norm_b
=
torch
::
empty_like
(
norm_b
);
// inputs.
const
T
*
grad_output_ptr
=
(
const
T
*
)
g_output
.
data_ptr
();
const
T
*
input_ptr
=
(
const
T
*
)
input
.
data_ptr
();
const
T
*
output_ptr
=
(
const
T
*
)
output
.
data_ptr
();
const
T
*
inp_norm_ptr
=
(
const
T
*
)
inp_norm
.
data_ptr
();
const
T
*
q_tf_ptr
=
(
const
T
*
)
qkv_tf
.
data_ptr
();
const
T
*
add_res_ptr
=
(
const
T
*
)
add_res
.
data_ptr
();
const
T
*
k_tf_ptr
=
q_tf_ptr
+
(
bsz
*
layer
->
GetSeqLength
()
*
output_w
.
size
(
0
));
//(const T*)k_tf.data_ptr();
const
T
*
v_tf_ptr
=
k_tf_ptr
+
(
bsz
*
layer
->
GetSeqLength
()
*
output_w
.
size
(
0
));
//(const T*)v_tf.data_ptr();
const
T
*
ff1_inp_ptr
=
(
const
T
*
)
ff1_inp
.
data_ptr
();
const
T
*
gelu_inp_ptr
=
(
const
T
*
)
gelu_inp
.
data_ptr
();
const
T
*
ff2_inp_ptr
=
(
const
T
*
)
ff2_inp
.
data_ptr
();
const
T
*
ctx_bufB_ptr
=
(
const
T
*
)
ctx_bufB
.
data_ptr
();
const
T
*
soft_out_ptr
=
(
const
T
*
)
soft_out
.
data_ptr
();
const
T
*
attn_o_inp_ptr
=
(
const
T
*
)
attn_o_inp
.
data_ptr
();
const
T
*
input_mask_ptr
=
(
const
T
*
)
input_mask
.
data_ptr
();
const
T
*
attn_qkvw_ptr
=
(
const
T
*
)
attn_qkvw
.
data_ptr
();
const
T
*
attn_ow_ptr
=
(
const
T
*
)
attn_ow
.
data_ptr
();
const
T
*
attn_nw_ptr
=
(
const
T
*
)
attn_nw
.
data_ptr
();
const
T
*
attn_nb_ptr
=
(
const
T
*
)
attn_nb
.
data_ptr
();
const
T
*
inter_w_ptr
=
(
const
T
*
)
inter_w
.
data_ptr
();
const
T
*
inter_b_ptr
=
(
const
T
*
)
inter_b
.
data_ptr
();
const
T
*
output_w_ptr
=
(
const
T
*
)
output_w
.
data_ptr
();
const
T
*
norm_w_ptr
=
(
const
T
*
)
norm_w
.
data_ptr
();
const
T
*
norm_b_ptr
=
(
const
T
*
)
norm_b
.
data_ptr
();
// outputs.
T
*
grad_input_ptr
=
(
T
*
)
grad_input
.
data_ptr
();
T
*
grad_attn_qkvw_ptr
=
(
T
*
)
grad_attn_qkvw
.
data_ptr
();
T
*
grad_attn_qkvb_ptr
=
(
T
*
)
grad_attn_qkvb
.
data_ptr
();
T
*
grad_attn_ow_ptr
=
(
T
*
)
grad_attn_ow
.
data_ptr
();
T
*
grad_attn_ob_ptr
=
(
T
*
)
grad_attn_ob
.
data_ptr
();
T
*
grad_attn_nw_ptr
=
(
T
*
)
grad_attn_nw
.
data_ptr
();
T
*
grad_attn_nb_ptr
=
(
T
*
)
grad_attn_nb
.
data_ptr
();
T
*
grad_inter_w_ptr
=
(
T
*
)
grad_inter_w
.
data_ptr
();
T
*
grad_inter_b_ptr
=
(
T
*
)
grad_inter_b
.
data_ptr
();
T
*
grad_output_w_ptr
=
(
T
*
)
grad_output_w
.
data_ptr
();
T
*
grad_output_b_ptr
=
(
T
*
)
grad_output_b
.
data_ptr
();
T
*
grad_norm_w_ptr
=
(
T
*
)
grad_norm_w
.
data_ptr
();
T
*
grad_norm_b_ptr
=
(
T
*
)
grad_norm_b
.
data_ptr
();
layer
->
SetIntermediateBuffers
((
uint8_t
*
)
attn_prob_dropout_mask
.
data_ptr
(),
(
uint8_t
*
)
attn_output_dropout_mask
.
data_ptr
(),
(
uint8_t
*
)
layer_output_dropout_mask
.
data_ptr
(),
(
T
*
)
attn_layer_norm_var
.
data_ptr
(),
(
T
*
)
attn_layer_norm_mean
.
data_ptr
(),
(
T
*
)
layer_norm_var
.
data_ptr
(),
(
T
*
)
layer_norm_mean
.
data_ptr
());
layer
->
Backward
(
bsz
,
grad_output_ptr
,
input_ptr
,
output_ptr
,
inp_norm_ptr
,
q_tf_ptr
,
k_tf_ptr
,
v_tf_ptr
,
soft_out_ptr
,
ctx_bufB_ptr
,
attn_o_inp_ptr
,
add_res_ptr
,
ff1_inp_ptr
,
gelu_inp_ptr
,
ff2_inp_ptr
,
input_mask_ptr
,
attn_qkvw_ptr
,
attn_ow_ptr
,
attn_nw_ptr
,
attn_nb_ptr
,
inter_w_ptr
,
inter_b_ptr
,
output_w_ptr
,
norm_w_ptr
,
norm_b_ptr
,
grad_input_ptr
,
grad_attn_qkvw_ptr
,
grad_attn_qkvb_ptr
,
grad_attn_ow_ptr
,
grad_attn_ob_ptr
,
grad_attn_nw_ptr
,
grad_attn_nb_ptr
,
grad_inter_w_ptr
,
grad_inter_b_ptr
,
grad_output_w_ptr
,
grad_output_b_ptr
,
grad_norm_w_ptr
,
grad_norm_b_ptr
);
return
{
grad_input
,
grad_attn_qkvw
,
grad_attn_qkvb
,
grad_attn_ow
,
grad_attn_ob
,
grad_attn_nw
,
grad_attn_nb
,
grad_inter_w
,
grad_inter_b
,
grad_output_w
,
grad_output_b
,
grad_norm_w
,
grad_norm_b
};
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"forward_fp32"
,
&
ds_transformer_forward
<
float
>
,
"DeepSpeed Transformer forward with fp32 (CUDA)"
);
m
.
def
(
"forward_fp16"
,
&
ds_transformer_forward
<
__half
>
,
"DeepSpeed Transformer forward with fp16 (CUDA)"
);
m
.
def
(
"backward_fp32"
,
&
ds_transformer_backward
<
float
>
,
"DeepSpeed Transformer backward with fp32 (CUDA)"
);
m
.
def
(
"backward_fp16"
,
&
ds_transformer_backward
<
__half
>
,
"DeepSpeed Transformer backward with fp16 (CUDA)"
);
m
.
def
(
"create_transformer_layer_fp32"
,
&
create_transformer_layer
<
float
>
,
"Create DeepSpeed Transformer Transformer Layer with fp32 (CUDA)"
);
m
.
def
(
"create_transformer_layer_fp16"
,
&
create_transformer_layer
<
__half
>
,
"Create DeepSpeed Transformer Transformer Layer with fp16 (CUDA)"
);
}
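Note (added for clarity, not part of the deleted file): Forward and Backward above carve all of their temporaries out of one preallocated workspace tensor with plain pointer arithmetic, in equal-sized slices of bsz * seq_length * hidden_size elements. A minimal standalone sketch of that base layout (ignoring the checkpointing variants; ForwardBuffers and partition_forward_buffers are illustrative names only):

#include <cstddef>

template <typename T>
struct ForwardBuffers {
    T* buf_0;  // QKV GEMM output
    T* buf_1;  // attention context / attention-output scratch
    T* buf_2;  // aliases buf_1 unless invertible-norm or GELU checkpointing shifts it
};

template <typename T>
ForwardBuffers<T> partition_forward_buffers(T* workspace, size_t bsz, size_t seq_length, size_t hidden_size)
{
    const size_t small_buf_size = bsz * seq_length * hidden_size;
    ForwardBuffers<T> bufs;
    bufs.buf_0 = workspace;
    bufs.buf_1 = bufs.buf_0 + small_buf_size;
    bufs.buf_2 = bufs.buf_1;
    return bufs;
}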
csrc/transformer_bak/gelu_kernels.cu
deleted
100644 → 0
View file @
7dd68788
#include "custom_cuda_layers.h"
inline
__device__
float
gelu
(
const
float
x
)
{
const
float
sqrt_param
=
0.79788456080286535587989211986876
f
;
const
float
mul_param
=
0.044715
;
return
x
*
0.5
f
*
(
1.0
f
+
tanhf
(
sqrt_param
*
(
x
+
mul_param
*
x
*
x
*
x
)));
}
inline
__device__
float
d_gelu
(
const
float
x
)
{
const
float
sqrt_param
=
0.79788456080286535587989211986876
f
;
const
float
mul_param
=
0.044715
;
float
x2mul
=
x
*
x
*
mul_param
;
float
tan_h
=
tanhf
(
sqrt_param
*
(
x
+
x
*
x2mul
));
float
dg1
=
0.5
f
*
(
1.0
f
+
tan_h
);
float
dg2
=
x
*
0.5
f
*
sqrt_param
*
(
1
-
tan_h
*
tan_h
);
float
dg3
=
dg2
*
3
*
x2mul
;
return
(
dg1
+
dg2
+
dg3
);
}
/*
Fused bias add with GELU
Loads a vector of 4 elements each iteration, for stride
iterations. It was written with the intention to launch 256 thread
threadblocks, so to launch for bert-large, we would set ITERATIONS
to 4. This is currently done automatically as a heuristic, setting
the number of iterations as blocks of 1024.
For FP16, the values are loaded from memory as __half, but converted
to FP32 for the arithmetic itself, to prevent numerous overflow on
the intermediate hyperbolic tangent, since there's no intrinsic
that computes it directly.
*/
__global__
void
gelu_kernel
(
const
float
*
input
,
float
*
vals
,
int
row_stride
,
int
iterations
)
{
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
loop_stride
=
blockDim
.
x
;
const
float4
*
input_cast
=
reinterpret_cast
<
const
float4
*>
(
input
);
float4
*
vals_cast
=
reinterpret_cast
<
float4
*>
(
vals
);
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
if
(
i
*
loop_stride
+
id
<
row_stride
)
{
float4
data
=
input_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
data
.
x
=
gelu
(
data
.
x
);
data
.
y
=
gelu
(
data
.
y
);
data
.
z
=
gelu
(
data
.
z
);
data
.
w
=
gelu
(
data
.
w
);
vals_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
]
=
data
;
}
}
}
__global__
void
gelu_kernel
(
const
__half
*
input
,
__half
*
vals
,
int
row_stride
,
int
iterations
)
{
#ifdef HALF_PRECISION_AVAILABLE
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
loop_stride
=
blockDim
.
x
;
const
float2
*
input_cast
=
reinterpret_cast
<
const
float2
*>
(
input
);
float2
*
vals_cast
=
reinterpret_cast
<
float2
*>
(
vals
);
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
if
(
i
*
loop_stride
+
id
<
row_stride
)
{
float2
vals_vec
=
input_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
__half2
*
vals_half
=
reinterpret_cast
<
__half2
*>
(
&
vals_vec
);
float2
low_data
=
__half22float2
(
vals_half
[
0
]);
float2
high_data
=
__half22float2
(
vals_half
[
1
]);
low_data
.
x
=
gelu
(
low_data
.
x
);
low_data
.
y
=
gelu
(
low_data
.
y
);
high_data
.
x
=
gelu
(
high_data
.
x
);
high_data
.
y
=
gelu
(
high_data
.
y
);
vals_half
[
0
]
=
__float22half2_rn
(
low_data
);
vals_half
[
1
]
=
__float22half2_rn
(
high_data
);
vals_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
]
=
vals_vec
;
}
}
#endif
}
__global__
void
fused_bias_gelu
(
const
float
*
input
,
const
float
*
bias
,
float
*
vals
,
int
row_stride
,
int
iterations
)
{
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
loop_stride
=
blockDim
.
x
;
const
float4
*
input_cast
=
reinterpret_cast
<
const
float4
*>
(
input
);
float4
*
vals_cast
=
reinterpret_cast
<
float4
*>
(
vals
);
const
float4
*
bias_cast
=
reinterpret_cast
<
const
float4
*>
(
bias
);
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
if
(
i
*
loop_stride
+
id
<
row_stride
)
{
float4
data
=
input_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
float4
bias_data
=
bias_cast
[
i
*
loop_stride
+
id
];
data
.
x
+=
bias_data
.
x
;
data
.
y
+=
bias_data
.
y
;
data
.
z
+=
bias_data
.
z
;
data
.
w
+=
bias_data
.
w
;
data
.
x
=
gelu
(
data
.
x
);
data
.
y
=
gelu
(
data
.
y
);
data
.
z
=
gelu
(
data
.
z
);
data
.
w
=
gelu
(
data
.
w
);
vals_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
]
=
data
;
}
}
}
__global__
void
fused_bias_gelu
(
const
__half
*
input
,
const
__half
*
bias
,
__half
*
vals
,
int
row_stride
,
int
iterations
)
{
#ifdef HALF_PRECISION_AVAILABLE
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
loop_stride
=
blockDim
.
x
;
const
float2
*
input_cast
=
reinterpret_cast
<
const
float2
*>
(
input
);
float2
*
vals_cast
=
reinterpret_cast
<
float2
*>
(
vals
);
const
float2
*
bias_cast
=
reinterpret_cast
<
const
float2
*>
(
bias
);
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
if
(
i
*
loop_stride
+
id
<
row_stride
)
{
float2
vals_vec
=
input_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
float2
bias_vec
=
bias_cast
[
i
*
loop_stride
+
id
];
__half2
*
vals_half
=
reinterpret_cast
<
__half2
*>
(
&
vals_vec
);
__half2
*
bias_half
=
reinterpret_cast
<
__half2
*>
(
&
bias_vec
);
float2
low_data
=
__half22float2
(
vals_half
[
0
]);
float2
high_data
=
__half22float2
(
vals_half
[
1
]);
float2
low_bias
=
__half22float2
(
bias_half
[
0
]);
float2
high_bias
=
__half22float2
(
bias_half
[
1
]);
low_data
.
x
+=
low_bias
.
x
;
low_data
.
y
+=
low_bias
.
y
;
high_data
.
x
+=
high_bias
.
x
;
high_data
.
y
+=
high_bias
.
y
;
low_data
.
x
=
gelu
(
low_data
.
x
);
low_data
.
y
=
gelu
(
low_data
.
y
);
high_data
.
x
=
gelu
(
high_data
.
x
);
high_data
.
y
=
gelu
(
high_data
.
y
);
vals_half
[
0
]
=
__float22half2_rn
(
low_data
);
vals_half
[
1
]
=
__float22half2_rn
(
high_data
);
vals_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
]
=
vals_vec
;
}
}
#endif
}
__global__
void
d_gelu_func
(
float
*
d_output
,
const
float
*
gelu_input
,
const
float
*
bias
,
int
row_stride
,
int
iterations
)
{
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
loop_stride
=
blockDim
.
x
;
float4
*
d_output_cast
=
reinterpret_cast
<
float4
*>
(
d_output
);
const
float4
*
gelu_input_cast
=
reinterpret_cast
<
const
float4
*>
(
gelu_input
);
const
float4
*
bias_cast
=
reinterpret_cast
<
const
float4
*>
(
bias
);
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
if
(
i
*
loop_stride
+
id
<
row_stride
)
{
float4
output_data
=
d_output_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
float4
gelu_input_data
=
gelu_input_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
float4
bias_data
=
bias_cast
[
i
*
loop_stride
+
id
];
gelu_input_data
.
x
+=
bias_data
.
x
;
gelu_input_data
.
y
+=
bias_data
.
y
;
gelu_input_data
.
z
+=
bias_data
.
z
;
gelu_input_data
.
w
+=
bias_data
.
w
;
output_data
.
x
*=
d_gelu
(
gelu_input_data
.
x
);
output_data
.
y
*=
d_gelu
(
gelu_input_data
.
y
);
output_data
.
z
*=
d_gelu
(
gelu_input_data
.
z
);
output_data
.
w
*=
d_gelu
(
gelu_input_data
.
w
);
d_output_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
]
=
output_data
;
}
}
}
__global__
void
d_gelu_func
(
__half
*
d_output
,
const
__half
*
gelu_input
,
const
__half
*
bias
,
int
row_stride
,
int
iterations
)
{
#ifdef HALF_PRECISION_AVAILABLE
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
loop_stride
=
blockDim
.
x
;
float2
*
d_output_cast
=
reinterpret_cast
<
float2
*>
(
d_output
);
const
float2
*
gelu_input_cast
=
reinterpret_cast
<
const
float2
*>
(
gelu_input
);
const
float2
*
bias_cast
=
reinterpret_cast
<
const
float2
*>
(
bias
);
#pragma unroll
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
if
(
i
*
loop_stride
+
id
<
row_stride
)
{
float2
output_data
=
d_output_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
float2
gelu_input_data
=
gelu_input_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
float2
bias_vec
=
bias_cast
[
i
*
loop_stride
+
id
];
__half2
*
output_data_half
=
reinterpret_cast
<
__half2
*>
(
&
output_data
);
__half2
*
gelu_input_data_half
=
reinterpret_cast
<
__half2
*>
(
&
gelu_input_data
);
__half2
*
bias_half
=
reinterpret_cast
<
__half2
*>
(
&
bias_vec
);
float2
output_half_0
=
__half22float2
(
output_data_half
[
0
]);
float2
output_half_1
=
__half22float2
(
output_data_half
[
1
]);
float2
gelu_input_half_0
=
__half22float2
(
gelu_input_data_half
[
0
]);
float2
gelu_input_half_1
=
__half22float2
(
gelu_input_data_half
[
1
]);
float2
bias_half_0
=
__half22float2
(
bias_half
[
0
]);
float2
bias_half_1
=
__half22float2
(
bias_half
[
1
]);
gelu_input_half_0
.
x
+=
bias_half_0
.
x
;
gelu_input_half_0
.
y
+=
bias_half_0
.
y
;
gelu_input_half_1
.
x
+=
bias_half_1
.
x
;
gelu_input_half_1
.
y
+=
bias_half_1
.
y
;
output_half_0
.
x
*=
d_gelu
(
gelu_input_half_0
.
x
);
output_half_0
.
y
*=
d_gelu
(
gelu_input_half_0
.
y
);
output_half_1
.
x
*=
d_gelu
(
gelu_input_half_1
.
x
);
output_half_1
.
y
*=
d_gelu
(
gelu_input_half_1
.
y
);
float2
result
;
__half2
*
result_half2
=
reinterpret_cast
<
__half2
*>
(
&
result
);
result_half2
[
0
]
=
__float22half2_rn
(
output_half_0
);
result_half2
[
1
]
=
__float22half2_rn
(
output_half_1
);
d_output_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
]
=
result
;
}
}
#endif
}
template
<
typename
T
>
void
launch_bias_gelu
(
const
T
*
input
,
const
T
*
bias
,
T
*
output
,
int
intermediate_size
,
int
batch_size
,
cudaStream_t
stream
)
{
int
iterations
=
(
intermediate_size
+
1023
)
/
1024
;
int
threads
=
(
intermediate_size
-
1
)
/
(
iterations
*
4
)
+
1
;
dim3
block_dims
(
threads
);
dim3
grid_dims
(
batch_size
);
fused_bias_gelu
<<<
grid_dims
,
block_dims
,
0
,
stream
>>>
(
input
,
bias
,
output
,
intermediate_size
/
4
,
iterations
);
}
template
<
typename
T
>
void
launch_gelu
(
const
T
*
input
,
T
*
output
,
int
intermediate_size
,
int
batch_size
,
cudaStream_t
stream
)
{
int
iterations
=
(
intermediate_size
+
1023
)
/
1024
;
int
threads
=
(
intermediate_size
-
1
)
/
(
iterations
*
4
)
+
1
;
dim3
block_dims
(
threads
);
dim3
grid_dims
(
batch_size
);
gelu_kernel
<<<
grid_dims
,
block_dims
,
0
,
stream
>>>
(
input
,
output
,
intermediate_size
/
4
,
iterations
);
}
template
void
launch_bias_gelu
<
float
>(
const
float
*
,
const
float
*
,
float
*
,
int
,
int
,
cudaStream_t
);
template
void
launch_bias_gelu
<
__half
>(
const
__half
*
,
const
__half
*
,
__half
*
,
int
,
int
,
cudaStream_t
);
template
void
launch_gelu
<
float
>(
const
float
*
,
float
*
,
int
,
int
,
cudaStream_t
);
template
void
launch_gelu
<
__half
>(
const
__half
*
,
__half
*
,
int
,
int
,
cudaStream_t
);
template
<
typename
T
>
void
launch_d_gelu
(
T
*
d_output
,
const
T
*
input
,
const
T
*
bias
,
int
intermediate_size
,
int
batch_size
,
cudaStream_t
stream
)
{
int
iterations
=
(
intermediate_size
+
1023
)
/
1024
;
int
threads
=
(
intermediate_size
-
1
)
/
(
iterations
*
4
)
+
1
;
dim3
block_dims
(
threads
);
dim3
grid_dims
(
batch_size
);
d_gelu_func
<<<
grid_dims
,
block_dims
,
0
,
stream
>>>
(
d_output
,
input
,
bias
,
intermediate_size
/
4
,
iterations
);
}
template
void
launch_d_gelu
<
float
>(
float
*
,
const
float
*
,
const
float
*
,
int
,
int
,
cudaStream_t
);
template
void
launch_d_gelu
<
__half
>(
__half
*
,
const
__half
*
,
const
__half
*
,
int
,
int
,
cudaStream_t
);
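For reference (this note is added for clarity and is not part of the deleted file): the gelu and d_gelu device functions above implement the tanh approximation of GELU and its exact derivative, with sqrt_param = sqrt(2/pi) and mul_param = 0.044715. Written out:

\[
\mathrm{gelu}(x) = 0.5\,x\bigl(1 + \tanh u\bigr), \qquad
u = \sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr), \qquad
\sqrt{2/\pi} \approx 0.79788456
\]
\[
\frac{d}{dx}\,\mathrm{gelu}(x)
= \underbrace{0.5\,(1+\tanh u)}_{dg1}
+ \underbrace{0.5\,x\,\sqrt{2/\pi}\,(1-\tanh^{2}u)}_{dg2}
+ \underbrace{dg2 \cdot 3 \cdot 0.044715\,x^{2}}_{dg3}
\]

The three terms correspond exactly to dg1, dg2, and dg3 computed in d_gelu.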
csrc/transformer_bak/gelu_kernels.hip
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
inline __device__ float d_gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
float x2mul = x * x * mul_param;
float tan_h = tanhf(sqrt_param * (x + x * x2mul));
float dg1 = 0.5f * (1.0f + tan_h);
float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
float dg3 = dg2 * 3 * x2mul;
return (dg1 + dg2 + dg3);
}
/*
Fused bias add with GELU
Loads a vector of 4 elements each iteration, for stride
iterations. It was written with the intention of launching 256-thread
threadblocks, so to launch for bert-large, we would set ITERATIONS
to 4. This is currently done automatically as a heuristic, setting
the number of iterations as blocks of 1024.
For FP16, the values are loaded from memory as __half, but converted
to FP32 for the arithmetic itself, to prevent overflow in
the intermediate hyperbolic tangent, since there's no intrinsic
that computes it directly.
*/
__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void fused_bias_gelu(const float* input,
const float* bias,
float* vals,
int row_stride,
int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void fused_bias_gelu(const __half* input,
const __half* bias,
__half* vals,
int row_stride,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void d_gelu_func(float* d_output,
const float* gelu_input,
const float* bias,
int row_stride,
int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
float4* d_output_cast = reinterpret_cast<float4*>(d_output);
const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
gelu_input_data.x += bias_data.x;
gelu_input_data.y += bias_data.y;
gelu_input_data.z += bias_data.z;
gelu_input_data.w += bias_data.w;
output_data.x *= d_gelu(gelu_input_data.x);
output_data.y *= d_gelu(gelu_input_data.y);
output_data.z *= d_gelu(gelu_input_data.z);
output_data.w *= d_gelu(gelu_input_data.w);
d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
}
}
}
__global__ void d_gelu_func(__half* d_output,
const __half* gelu_input,
const __half* bias,
int row_stride,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
float2* d_output_cast = reinterpret_cast<float2*>(d_output);
const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
#pragma unroll
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
__half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 output_half_0 = __half22float2(output_data_half[0]);
float2 output_half_1 = __half22float2(output_data_half[1]);
float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);
float2 bias_half_0 = __half22float2(bias_half[0]);
float2 bias_half_1 = __half22float2(bias_half[1]);
gelu_input_half_0.x += bias_half_0.x;
gelu_input_half_0.y += bias_half_0.y;
gelu_input_half_1.x += bias_half_1.x;
gelu_input_half_1.y += bias_half_1.y;
output_half_0.x *= d_gelu(gelu_input_half_0.x);
output_half_0.y *= d_gelu(gelu_input_half_0.y);
output_half_1.x *= d_gelu(gelu_input_half_1.x);
output_half_1.y *= d_gelu(gelu_input_half_1.y);
float2 result;
__half2* result_half2 = reinterpret_cast<__half2*>(&result);
result_half2[0] = __float22half2_rn(output_half_0);
result_half2[1] = __float22half2_rn(output_half_1);
d_output_cast[row * row_stride + i * loop_stride + id] = result;
}
}
#endif
}
template <typename T>
void launch_bias_gelu(const T* input,
const T* bias,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, output, intermediate_size / 4, iterations);
}
template <typename T>
void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( gelu_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
input, output, intermediate_size / 4, iterations);
}
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(const __half*,
const __half*,
__half*,
int,
int,
hipStream_t);
template void launch_gelu<float>(const float*, float*, int, int, hipStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, hipStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
const T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( d_gelu_func), dim3(grid_dims), dim3(block_dims), 0, stream,
d_output, input, bias, intermediate_size / 4, iterations);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, hipStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, hipStream_t);
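A small illustration of the launch heuristic shared by launch_bias_gelu, launch_gelu, and launch_d_gelu above (added for clarity; compute_gelu_launch_dims is an illustrative name, not part of the deleted files): each thread processes 4 elements per iteration and the number of iterations tiles the row in chunks of 1024 elements, so a bert-large intermediate size of 4096 yields 4 iterations of 256 threads, matching the comment at the top of the file.

// Illustrative host-side sketch only: reproduces the grid/block heuristic of the launchers above.
#include <cstdio>

static void compute_gelu_launch_dims(int intermediate_size, int* threads, int* iterations)
{
    *iterations = (intermediate_size + 1023) / 1024;             // chunks of 1024 elements per row
    *threads = (intermediate_size - 1) / (*iterations * 4) + 1;  // 4 elements per thread per iteration
}

int main()
{
    int threads = 0, iterations = 0;
    compute_gelu_launch_dims(4096, &threads, &iterations);       // bert-large intermediate size
    std::printf("threads=%d iterations=%d\n", threads, iterations);  // prints: threads=256 iterations=4
    return 0;
}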
csrc/transformer_bak/general_kernels.cu
deleted
100644 → 0
View file @
7dd68788
#include "general_kernels.h"
namespace
cg
=
cooperative_groups
;
template
<
typename
T
>
__global__
void
column_sum_reduce
(
const
T
*
__restrict__
inp
,
T
*
__restrict__
out
,
int
rows
,
int
width
)
{
__shared__
float
tile
[
TILE_DIM
][
TILE_DIM
+
1
];
cg
::
thread_block
b
=
cg
::
this_thread_block
();
cg
::
thread_block_tile
<
TILE_DIM
>
g
=
cg
::
tiled_partition
<
TILE_DIM
>
(
b
);
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int
y_stride
=
width
*
TILE_DIM
;
float
localSum
=
0
;
// Loop across matrix height
if
(
idx
<
width
)
{
int
offset
=
threadIdx
.
y
*
width
+
idx
;
for
(
int
r
=
threadIdx
.
y
;
r
<
rows
;
r
+=
TILE_DIM
)
{
localSum
+=
(
float
)
inp
[
offset
];
offset
+=
y_stride
;
}
}
tile
[
threadIdx
.
x
][
threadIdx
.
y
]
=
localSum
;
__syncthreads
();
// Sum the shared buffer.
float
sum
=
tile
[
threadIdx
.
y
][
threadIdx
.
x
];
#ifndef __STOCHASTIC_MODE__
__syncthreads
();
#endif
for
(
int
i
=
1
;
i
<
TILE_DIM
;
i
<<=
1
)
sum
+=
g
.
shfl_down
(
sum
,
i
);
if
(
threadIdx
.
x
==
0
)
{
int
pos
=
blockIdx
.
x
*
TILE_DIM
+
threadIdx
.
y
;
if
(
pos
<
width
)
out
[
pos
]
=
sum
;
}
}
template
<
typename
T
>
void
launch_fuse_transpose_bias_kernel
(
const
T
*
inp
,
T
*
out
,
int
rows
,
int
cols
,
cudaStream_t
stream
);
template
<
>
void
launch_fuse_transpose_bias_kernel
<
float
>
(
const
float
*
inp
,
float
*
out
,
int
rows
,
int
cols
,
cudaStream_t
stream
)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3
grid_dim
((
cols
-
1
)
/
TILE_DIM
+
1
);
dim3
block_dim
(
TILE_DIM
,
TILE_DIM
);
column_sum_reduce
<
float
><<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
inp
,
out
,
rows
,
cols
);
}
template
<
>
void
launch_fuse_transpose_bias_kernel
<
__half
>
(
const
__half
*
inp
,
__half
*
out
,
int
rows
,
int
cols
,
cudaStream_t
stream
)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3
grid_dim
((
cols
-
1
)
/
TILE_DIM
+
1
);
dim3
block_dim
(
TILE_DIM
,
TILE_DIM
);
column_sum_reduce
<
__half
><<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
inp
,
out
,
rows
,
cols
);
}
__global__
void
fused_add2_kernel
(
const
int
N
,
float
*
out
,
const
float
*
inp1
,
const
float
*
inp2
)
{
const
float4
*
inp1_4
=
reinterpret_cast
<
const
float4
*>
(
inp1
);
const
float4
*
inp2_4
=
reinterpret_cast
<
const
float4
*>
(
inp2
);
float4
*
out_4
=
reinterpret_cast
<
float4
*>
(
out
);
CUDA_1D_KERNEL_LOOP
(
j
,
N
)
{
float4
val
;
float4
inp1_reg
=
inp1_4
[
j
];
float4
inp2_reg
=
inp2_4
[
j
];
val
.
x
=
inp1_reg
.
x
+
inp2_reg
.
x
;
val
.
y
=
inp1_reg
.
y
+
inp2_reg
.
y
;
val
.
z
=
inp1_reg
.
z
+
inp2_reg
.
z
;
val
.
w
=
inp1_reg
.
w
+
inp2_reg
.
w
;
out_4
[
j
]
=
val
;
}
}
__global__
void
fused_add2_kernel
(
const
int
N
,
__half
*
out
,
const
__half
*
inp1
,
const
__half
*
inp2
)
{
float2
inp1_4
;
float2
inp2_4
;
__half2
*
inp1_h
=
reinterpret_cast
<
__half2
*>
(
&
inp1_4
);
__half2
*
inp2_h
=
reinterpret_cast
<
__half2
*>
(
&
inp2_4
);
const
float2
*
inp1_arr
=
reinterpret_cast
<
const
float2
*>
(
inp1
);
const
float2
*
inp2_arr
=
reinterpret_cast
<
const
float2
*>
(
inp2
);
CUDA_1D_KERNEL_LOOP
(
j
,
N
)
{
inp1_4
=
inp1_arr
[
j
];
inp2_4
=
inp2_arr
[
j
];
float2
inp1_h_f_0
=
__half22float2
(
inp1_h
[
0
]);
float2
inp1_h_f_1
=
__half22float2
(
inp1_h
[
1
]);
float2
inp2_h_f_0
=
__half22float2
(
inp2_h
[
0
]);
float2
inp2_h_f_1
=
__half22float2
(
inp2_h
[
1
]);
inp1_h_f_0
.
x
+=
inp2_h_f_0
.
x
;
inp1_h_f_0
.
y
+=
inp2_h_f_0
.
y
;
inp1_h_f_1
.
x
+=
inp2_h_f_1
.
x
;
inp1_h_f_1
.
y
+=
inp2_h_f_1
.
y
;
float2
val_f
;
__half2
*
val_h
=
reinterpret_cast
<
__half2
*>
(
&
val_f
);
val_h
[
0
]
=
__float22half2_rn
(
inp1_h_f_0
);
val_h
[
1
]
=
__float22half2_rn
(
inp1_h_f_1
);
float2
*
out_4
=
reinterpret_cast
<
float2
*>
(
out
);
out_4
[
j
]
=
val_f
;
}
}
template
<
>
void
launch_fused_add2
<
float
>
(
float
*
out
,
const
float
*
inp1
,
const
float
*
inp2
,
int
batch_size
,
int
seq_length
,
int
hidden_dim
,
cudaStream_t
&
stream
)
{
int
total_count
=
batch_size
*
seq_length
*
hidden_dim
/
4
;
dim3
grid_dim
=
DS_GET_BLOCKS
(
total_count
);
//(batch_size * seq_length);
dim3
block_dim
=
DS_CUDA_NUM_THREADS
;
//(hidden_dim / 4);
fused_add2_kernel
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
total_count
,
out
,
inp1
,
inp2
);
}
template
<
>
void
launch_fused_add2
<
__half
>
(
__half
*
out
,
const
__half
*
inp1
,
const
__half
*
inp2
,
int
batch_size
,
int
seq_length
,
int
hidden_dim
,
cudaStream_t
&
stream
)
{
int
total_count
=
batch_size
*
seq_length
*
hidden_dim
/
4
;
dim3
grid_dim
=
DS_GET_BLOCKS
(
total_count
);
//(batch_size * seq_length);
dim3
block_dim
=
DS_CUDA_NUM_THREADS
;
//(hidden_dim / 4);
fused_add2_kernel
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
total_count
,
out
,
inp1
,
inp2
);
}
__global__
void
fused_add3_kernel
(
float
*
out
,
const
float
*
inp1
,
const
float
*
inp2
,
const
float
*
inp3
,
int
size
,
int
row_stride
)
{
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
const
float4
*
inp1_4
=
reinterpret_cast
<
const
float4
*>
(
inp1
);
const
float4
*
inp2_4
=
reinterpret_cast
<
const
float4
*>
(
inp2
);
const
float4
*
inp3_4
=
reinterpret_cast
<
const
float4
*>
(
inp3
);
float4
*
out_4
=
reinterpret_cast
<
float4
*>
(
out
);
float4
val
;
float4
inp1_reg
=
inp1_4
[
row
*
row_stride
+
id
];
float4
inp2_reg
=
inp2_4
[
row
*
row_stride
+
id
];
float4
inp3_reg
=
inp3_4
[
row
*
row_stride
+
id
];
val
.
x
=
inp1_reg
.
x
+
inp2_reg
.
x
+
inp3_reg
.
x
;
val
.
y
=
inp1_reg
.
y
+
inp2_reg
.
y
+
inp3_reg
.
y
;
val
.
z
=
inp1_reg
.
z
+
inp2_reg
.
z
+
inp3_reg
.
z
;
val
.
w
=
inp1_reg
.
w
+
inp2_reg
.
w
+
inp3_reg
.
w
;
out_4
[
row
*
row_stride
+
id
]
=
val
;
}
__global__
void
fused_add3_kernel
(
__half
*
out
,
const
__half
*
inp1
,
const
__half
*
inp2
,
const
__half
*
inp3
,
int
size
,
int
row_stride
)
{
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
const
float2
*
inp1_arr
=
reinterpret_cast
<
const
float2
*>
(
inp1
);
const
float2
*
inp2_arr
=
reinterpret_cast
<
const
float2
*>
(
inp2
);
const
float2
*
inp3_arr
=
reinterpret_cast
<
const
float2
*>
(
inp3
);
float2
inp1_4
=
inp1_arr
[
row
*
row_stride
+
id
];
float2
inp2_4
=
inp2_arr
[
row
*
row_stride
+
id
];
float2
inp3_4
=
inp3_arr
[
row
*
row_stride
+
id
];
__half2
*
inp1_h
=
reinterpret_cast
<
__half2
*>
(
&
inp1_4
);
__half2
*
inp2_h
=
reinterpret_cast
<
__half2
*>
(
&
inp2_4
);
__half2
*
inp3_h
=
reinterpret_cast
<
__half2
*>
(
&
inp3_4
);
float2
inp1_h_f_0
=
__half22float2
(
inp1_h
[
0
]);
float2
inp1_h_f_1
=
__half22float2
(
inp1_h
[
1
]);
float2
inp2_h_f_0
=
__half22float2
(
inp2_h
[
0
]);
float2
inp2_h_f_1
=
__half22float2
(
inp2_h
[
1
]);
float2
inp3_h_f_0
=
__half22float2
(
inp3_h
[
0
]);
float2
inp3_h_f_1
=
__half22float2
(
inp3_h
[
1
]);
inp1_h_f_0
.
x
+=
(
inp2_h_f_0
.
x
+
inp3_h_f_0
.
x
);
inp1_h_f_0
.
y
+=
(
inp2_h_f_0
.
y
+
inp3_h_f_0
.
y
);
inp1_h_f_1
.
x
+=
(
inp2_h_f_1
.
x
+
inp3_h_f_1
.
x
);
inp1_h_f_1
.
y
+=
(
inp2_h_f_1
.
y
+
inp3_h_f_1
.
y
);
float2
val_f
;
__half2
*
val_h
=
reinterpret_cast
<
__half2
*>
(
&
val_f
);
val_h
[
0
]
=
__float22half2_rn
(
inp1_h_f_0
);
val_h
[
1
]
=
__float22half2_rn
(
inp1_h_f_1
);
float2
*
out_4
=
reinterpret_cast
<
float2
*>
(
out
);
out_4
[
row
*
row_stride
+
id
]
=
val_f
;
}
template
<
>
void
launch_fused_add3
<
float
>
(
float
*
out
,
const
float
*
inp1
,
const
float
*
inp2
,
const
float
*
inp3
,
int
batch_size
,
int
seq_length
,
int
hidden_size
,
cudaStream_t
&
stream
)
{
dim3
grid_dim
(
batch_size
*
seq_length
);
dim3
block_dim
(
hidden_size
/
4
);
fused_add3_kernel
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
out
,
inp1
,
inp2
,
inp3
,
(
batch_size
*
seq_length
*
hidden_size
),
hidden_size
/
4
);
}
template
<
>
void
launch_fused_add3
<
__half
>
(
__half
*
out
,
const
__half
*
inp1
,
const
__half
*
inp2
,
const
__half
*
inp3
,
int
batch_size
,
int
seq_length
,
int
hidden_size
,
cudaStream_t
&
stream
)
{
dim3
grid_dim
(
batch_size
*
seq_length
);
dim3
block_dim
(
hidden_size
/
4
);
fused_add3_kernel
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
out
,
inp1
,
inp2
,
inp3
,
(
batch_size
*
seq_length
*
hidden_size
),
hidden_size
/
4
);
}
__global__
void
fused_add4_kernel
(
float
*
out
,
                                  const float* inp1,
                                  const float* inp2,
                                  const float* inp3,
                                  const float* inp4,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
    const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
    float4* out_4 = reinterpret_cast<float4*>(out);
    float4 val;
    float4 inp1_reg = inp1_4[row * row_stride + id];
    float4 inp2_reg = inp2_4[row * row_stride + id];
    float4 inp3_reg = inp3_4[row * row_stride + id];
    float4 inp4_reg = inp4_4[row * row_stride + id];
    val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
    val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
    val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
    val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;
    out_4[row * row_stride + id] = val;
}
__global__ void fused_add4_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  const __half* inp3,
                                  const __half* inp4,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
    const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);
    float2 inp1_4 = inp1_arr[row * row_stride + id];
    float2 inp2_4 = inp2_arr[row * row_stride + id];
    float2 inp3_4 = inp3_arr[row * row_stride + id];
    float2 inp4_4 = inp4_arr[row * row_stride + id];
    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
    __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
    __half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);
    float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
    float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
    float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
    float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
    float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
    float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
    float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
    float2 inp4_h_f_1 = __half22float2(inp4_h[1]);
    inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
    inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
    inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
    inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);
    float2 val_f;
    __half2* val_h = reinterpret_cast<__half2*>(&val_f);
    val_h[0] = __float22half2_rn(inp1_h_f_0);
    val_h[1] = __float22half2_rn(inp1_h_f_1);
    float2* out_4 = reinterpret_cast<float2*>(out);
    out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add4<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              const float* inp3,
                              const float* inp4,
                              int batch_size,
                              int seq_length,
                              int hidden_size,
                              cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);
    fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add4<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               const __half* inp3,
                               const __half* inp4,
                               int batch_size,
                               int seq_length,
                               int hidden_size,
                               cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);
    fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
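Note: the launchers above map one thread block per (batch, sequence) row and one thread per float4 (or per four __half values), so they implicitly assume hidden_size is divisible by 4 and that hidden_size / 4 does not exceed the maximum block size. A minimal, hypothetical host-side call is sketched below; the buffer and function names are illustrative and not part of the original sources.

// Hypothetical usage sketch for launch_fused_add4<float>; assumes hidden_size % 4 == 0
// and hidden_size / 4 <= 1024, since each thread adds exactly one float4 per input.
void example_fused_add4(float* d_out,
                        const float* d_a,
                        const float* d_b,
                        const float* d_c,
                        const float* d_d,
                        int batch_size,
                        int seq_length,
                        int hidden_size,
                        cudaStream_t stream)
{
    launch_fused_add4<float>(d_out, d_a, d_b, d_c, d_d,
                             batch_size, seq_length, hidden_size, stream);
}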
csrc/transformer_bak/general_kernels.hip
deleted
100644 → 0
View file @
7dd68788
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "general_kernels_hip.h"
namespace cg = cooperative_groups;
template <typename T>
__global__ void column_sum_reduce(const T* __restrict__ inp,
T* __restrict__ out,
int rows,
int width)
{
__shared__ float tile[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int y_stride = width * TILE_DIM;
float localSum = 0;
// Loop across matrix height
if (idx < width) {
int offset = threadIdx.y * width + idx;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
}
}
tile[threadIdx.x][threadIdx.y] = localSum;
__syncthreads();
// Sum the shared buffer.
float sum = tile[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
if (pos < width) out[pos] = sum;
}
}
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
T* out,
int rows,
int cols,
hipStream_t stream);
template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
float* out,
int rows,
int cols,
hipStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( column_sum_reduce<float>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
}
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
__half* out,
int rows,
int cols,
hipStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( column_sum_reduce<__half>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
}
__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
float4* out_4 = reinterpret_cast<float4*>(out);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 val;
float4 inp1_reg = inp1_4[j];
float4 inp2_reg = inp2_4[j];
val.x = inp1_reg.x + inp2_reg.x;
val.y = inp1_reg.y + inp2_reg.y;
val.z = inp1_reg.z + inp2_reg.z;
val.w = inp1_reg.w + inp2_reg.w;
out_4[j] = val;
}
}
__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
float2 inp1_4;
float2 inp2_4;
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
CUDA_1D_KERNEL_LOOP(j, N)
{
inp1_4 = inp1_arr[j];
inp2_4 = inp2_arr[j];
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
inp1_h_f_0.x += inp2_h_f_0.x;
inp1_h_f_0.y += inp2_h_f_0.y;
inp1_h_f_1.x += inp2_h_f_1.x;
inp1_h_f_1.y += inp2_h_f_1.y;
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[j] = val_f;
}
}
template <>
void launch_fused_add2<float>(float* out,
const float* inp1,
const float* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
template <>
void launch_fused_add2<__half>(__half* out,
const __half* inp1,
const __half* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
__global__ void fused_add3_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add3_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add3<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add3<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
__global__ void fused_add4_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
float4 inp4_reg = inp4_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add4_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
float2 inp4_4 = inp4_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
__half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
float2 inp4_h_f_1 = __half22float2(inp4_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add4<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add4<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
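As a point of comparison, launch_fused_add2 above sizes its launch differently from fused_add3/fused_add4: it flattens the tensor into total_count = batch_size * seq_length * hidden_dim / 4 vectorized elements and walks them with a grid-stride loop (CUDA_1D_KERNEL_LOOP), rather than mapping one block per row. A hypothetical host-side call of the HIP variant is sketched below; the buffer and function names are illustrative only.

// Hypothetical usage sketch for launch_fused_add2<__half> (not part of the original file).
// hidden_dim must be a multiple of 4, since each loop iteration processes one float2,
// i.e. four __half values.
void example_fused_add2(__half* d_out,
                        const __half* d_residual,
                        const __half* d_hidden,
                        int batch_size,
                        int seq_length,
                        int hidden_dim,
                        hipStream_t stream)
{
    launch_fused_add2<__half>(d_out, d_residual, d_hidden,
                              batch_size, seq_length, hidden_dim, stream);
}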
csrc/transformer_bak/inference/csrc/apply_rotary_pos_emb.cu
deleted
100644 → 0
View file @
7dd68788
#include "custom_cuda_layers.h"
//#include <cuda_profiler_api.h>
namespace cg = cooperative_groups;
__global__ void apply_rotary_pos_emb(float* mixed_query,
                                     float* key_layer,
                                     unsigned rotary_dim,
                                     unsigned seq_len,
                                     unsigned seq_offset,
                                     unsigned num_heads,
                                     unsigned head_size,
                                     unsigned total_count)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
    int id = threadIdx.x;
    int gid = id >> 5;
    int lane = id & 0x1f;
    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
    unsigned offset = head_id * head_size;
    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
    if (head_id < total_count) {
        while (lane < rotary_dim) {
            float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
            float q = mixed_query[offset + lane];
            float k = key_layer[offset + lane];
            float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
            float q_rot = (q * rotary_sign);
            float k_rot = (k * rotary_sign);
            q_rot = g.shfl_xor(q_rot, 1);
            k_rot = g.shfl_xor(k_rot, 1);
            q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
            k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
            mixed_query[offset + lane] = q;
            key_layer[offset + lane] = k;
            lane += WARP_SIZE;
        }
    }
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
                                     __half* key_layer,
                                     unsigned rotary_dim,
                                     unsigned seq_len,
                                     unsigned seq_offset,
                                     unsigned num_heads,
                                     unsigned head_size,
                                     unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
    int id = threadIdx.x;
    int gid = id >> 5;
    int lane = id & 0x1f;
    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
    unsigned offset = head_id * head_size;
    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
    if (head_id < total_count) {
        while (lane < rotary_dim) {
            float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
            float q = (float)mixed_query[offset + lane];
            float k = (float)key_layer[offset + lane];
            float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
            float q_rot = (q * rotary_sign);
            float k_rot = (k * rotary_sign);
            q_rot = g.shfl_xor(q_rot, 1);
            k_rot = g.shfl_xor(k_rot, 1);
            q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
            k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
            mixed_query[offset + lane] = (__half)q;
            key_layer[offset + lane] = (__half)k;
            lane += WARP_SIZE;
        }
    }
#endif
}
__global__ void apply_rotary_pos_emb1(float* mixed_query,
                                      float* key_layer,
                                      unsigned rotary_dim,
                                      unsigned seq_len,
                                      unsigned seq_offset,
                                      unsigned num_heads,
                                      unsigned head_size,
                                      unsigned total_count)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
    int id = threadIdx.x;
    int gid = id >> 5;
    int lane = id & 0x1f;
    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
    unsigned offset = head_id * head_size;
    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
    if (head_id < total_count) {
        while (lane < rotary_dim) {
            float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
            float q = mixed_query[offset + lane];
            float k = key_layer[offset + lane];
            float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
            float q_rot = (q * rotary_sign);
            float k_rot = (k * rotary_sign);
            q_rot = g.shfl_xor(q_rot, 1);
            k_rot = g.shfl_xor(k_rot, 1);
            q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
            k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
            mixed_query[offset + lane] = q;
            key_layer[offset + lane] = k;
            lane += WARP_SIZE;
        }
    }
}
__global__ void apply_rotary_pos_emb1(__half* mixed_query,
                                      __half* key_layer,
                                      unsigned rotary_dim,
                                      unsigned seq_len,
                                      unsigned seq_offset,
                                      unsigned num_heads,
                                      unsigned head_size,
                                      unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
    int id = threadIdx.x;
    int gid = id >> 5;
    int lane = id & 0x1f;
    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
    unsigned offset = head_id * head_size;
    constexpr unsigned mask[32] = {
        0x1 | 0x1000,     0x2 | 0x2000,     0x4 | 0x4000,     0x8 | 0x8000,
        0x10 | 0x10000,   0x20 | 0x20000,   0x40 | 0x40000,   0x80 | 0x80000,
        0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000,
        0x1000 | 0x1,     0x2000 | 0x2,     0x4000 | 0x4,     0x8000 | 0x8,
        0x10000 | 0x10,   0x20000 | 0x20,   0x40000 | 0x40,   0x80000 | 0x80,
        0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800,
        0x1000000,        0x2000000,        0x4000000,        0x8000000,
        0x10000000,       0x20000000,       0x40000000,       0x80000000};
    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
    unsigned half_dim = rotary_dim >> 1;
    if (head_id < total_count) {
        while (lane < rotary_dim) {
            float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
            float q = (float)mixed_query[offset + lane];
            float k = (float)key_layer[offset + lane];
            float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
            float q_rot = (q * rotary_sign);
            float k_rot = (k * rotary_sign);
            auto q_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], q_rot, lane + half_dim)
                                             : __shfl_sync(mask[lane], q_rot, lane - half_dim);
            auto k_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], k_rot, lane + half_dim)
                                             : __shfl_sync(mask[lane], k_rot, lane - half_dim);
            q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
            k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
            mixed_query[offset + lane] = (__half)q;
            key_layer[offset + lane] = (__half)k;
            lane += WARP_SIZE;
        }
    }
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
                                 T* key_layer,
                                 unsigned head_size,
                                 unsigned seq_len,
                                 unsigned rotary_dim,
                                 unsigned offset,
                                 unsigned num_heads,
                                 unsigned batch,
                                 bool rotate_half,
                                 bool rotate_every_two,
                                 cudaStream_t stream)
{
    int total_count = batch * num_heads * seq_len;
    dim3 block_dims(1024);
    dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1);  // (batch_size);
    if (rotate_every_two)
        apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
            mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
    else if (rotate_half)
        apply_rotary_pos_emb1<<<grid_dims, block_dims, 0, stream>>>(
            mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
                                                 float*,
                                                 unsigned,
                                                 unsigned,
                                                 unsigned,
                                                 unsigned,
                                                 unsigned,
                                                 unsigned,
                                                 bool,
                                                 bool,
                                                 cudaStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
                                                  __half*,
                                                  unsigned,
                                                  unsigned,
                                                  unsigned,
                                                  unsigned,
                                                  unsigned,
                                                  unsigned,
                                                  bool,
                                                  bool,
                                                  cudaStream_t);
/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
constexpr unsigned mask[32] = {0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000,
0x10 | 0x10000, 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000,
0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000,
0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, 0x8000 | 0x8,
0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800,
0x1000000, 0x2000000, 0x4000000, 0x8000000,
0x10000000, 0x20000000, 0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
//float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
float inv_freq = (float)((lane % (rotary_dim >> 1)) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > 11 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], q_rot, lane + 12) : __shfl_sync(mask[lane],
q_rot, lane - 12);//g.shfl_xor(q_rot, 12); auto k_rot_tmp = lane < 12 ? __shfl_sync(mask[lane],
k_rot, lane + 12) : __shfl_sync(mask[lane], k_rot, lane - 12);//g.shfl_xor(k_rot, 12); q = q *
cosf(inv_freq) + q_rot_tmp * sinf(inv_freq); k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
cudaStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
*/
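For completeness, a hypothetical caller of the deleted rotary-embedding launcher is sketched below. As the kernels show, rotate_every_two pairs each lane with its neighbour (g.shfl_xor(..., 1)), while rotate_half exchanges the two halves of rotary_dim; the indexing (offset = head_id * head_size, seq_id = (head_id / num_heads) % seq_len + seq_offset) implies a [batch, seq_len, num_heads, head_size] layout. All buffer and function names here are illustrative, not from the original sources.

// Hypothetical usage sketch: apply rotary position embeddings in place to query/key buffers.
void example_apply_rope(float* d_query,
                        float* d_key,
                        unsigned head_size,
                        unsigned seq_len,
                        unsigned rotary_dim,
                        unsigned past_offset,  // tokens already processed before this step
                        unsigned num_heads,
                        unsigned batch,
                        cudaStream_t stream)
{
    // rotate_every_two = true selects the adjacent-pair kernel; rotate_half = true
    // would select apply_rotary_pos_emb1 (half-swap) instead.
    launch_apply_rotary_pos_emb<float>(d_query, d_key, head_size, seq_len, rotary_dim,
                                       past_offset, num_heads, batch,
                                       /*rotate_half=*/false, /*rotate_every_two=*/true,
                                       stream);
}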