OpenDAS / deepspeed — Commit 7d1a83a9, authored May 25, 2022 by aiss
push Deepspeed 0.6.3 rocm version

parent ab5534fc · Changes: 162
Showing 20 changed files with 8713 additions and 0 deletions (+8713 -0)
csrc/transformer_bak/ds_transformer_hip.cpp                    +1052  -0
csrc/transformer_bak/gelu_kernels.cu                           +330   -0
csrc/transformer_bak/gelu_kernels.hip                          +332   -0
csrc/transformer_bak/general_kernels.cu                        +411   -0
csrc/transformer_bak/general_kernels.hip                       +413   -0
csrc/transformer_bak/inference/csrc/apply_rotary_pos_emb.cu    +372   -0
csrc/transformer_bak/inference/csrc/apply_rotary_pos_emb.hip   +374   -0
csrc/transformer_bak/inference/csrc/dequantize.cu              +110   -0
csrc/transformer_bak/inference/csrc/dequantize.hip             +112   -0
csrc/transformer_bak/inference/csrc/gelu.cu                    +525   -0
csrc/transformer_bak/inference/csrc/gelu.hip                   +527   -0
csrc/transformer_bak/inference/csrc/normalize.cu               +451   -0
csrc/transformer_bak/inference/csrc/normalize.hip              +453   -0
csrc/transformer_bak/inference/csrc/pt_binding.cpp             +911   -0
csrc/transformer_bak/inference/csrc/pt_binding_hip.cpp         +912   -0
csrc/transformer_bak/inference/csrc/softmax.cu                 +432   -0
csrc/transformer_bak/inference/csrc/softmax.hip                +434   -0
csrc/transformer_bak/inference/includes/context.h              +177   -0
csrc/transformer_bak/inference/includes/context_hip.h          +178   -0
csrc/transformer_bak/inference/includes/cublas_wrappers.h      +207   -0
csrc/transformer_bak/ds_transformer_hip.cpp  0 → 100644
// !!! This is a file automatically generated by hipify!!!
#include <torch/extension.h>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer_hip.h"
#include "context_hip.h"
#include "cublas_wrappers_hip.h"
#include "custom_hip_layers.h"
#include "ds_transformer_hip.h"

static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

const int init_seq_length = 128;

// C++ interface

template <typename T>
unsigned get_workspace_size(unsigned maxBatchSize,
                            unsigned seq_len,
                            unsigned hidden_size,
                            unsigned intermediate_size,
                            unsigned heads,
                            bool training,
                            bool gelu_checkpoint)
{
    unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
    if (training) {
        workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
        workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
                                     2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
        if (gelu_checkpoint)
            workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
    }
    return workSpacesize;  // * sizeof(T);
}

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
                                              unsigned batch_size,
                                              unsigned hidden_size,
                                              unsigned num_heads,
                                              unsigned intermediate_size,
                                              unsigned seq_length,
                                              float attn_prob_dropout_ratio,
                                              float hidden_output_dropout_ratio,
                                              float layer_norm_eps,
                                              bool pre_or_postLayerNorm,
                                              const std::vector<std::array<int, 3>>& gemm_algos,
                                              bool attn_dropout_checkpoint,
                                              bool normalize_invertible,
                                              bool gelu_checkpoint,
                                              bool stochastic_mode)
    : _layer_id(layer_id),
      _batch_size(batch_size),
      _hidden_size(hidden_size),
      _heads(num_heads),
      _intermediate_size(intermediate_size),
      _seq_length(seq_length),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _attn_dropout_checkpoint(attn_dropout_checkpoint),
      _normalize_invertible(normalize_invertible),
      _gelu_checkpoint(gelu_checkpoint),
      _stochastic_mode(stochastic_mode),
      _stream(Context::Instance().GetCurrentStream()),
      _cublasHandle(Context::Instance().GetCublasHandle()),
      _qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                  3 * hidden_size,
                                                  hidden_size,
                                                  gemm_algos[0])),
      _attn_out_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                       hidden_size,
                                                       hidden_size,
                                                       gemm_algos[0])),
      _attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                           seq_length,
                                                           hidden_size,
                                                           layer_norm_eps,
                                                           true,
                                                           !normalize_invertible)),
      _layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                      seq_length,
                                                      hidden_size,
                                                      layer_norm_eps,
                                                      true,
                                                      !normalize_invertible)),
      _ff1(typename FeedForward<T>::Config(batch_size * seq_length,
                                           _intermediate_size,
                                           hidden_size,
                                           gemm_algos[1])),
      _ff2(typename FeedForward<T>::Config(batch_size * seq_length,
                                           hidden_size,
                                           _intermediate_size,
                                           gemm_algos[2])),
      _softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
      _gelu(typename Gelu<T>::Config(_intermediate_size)),
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
      _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                        _seq_length,
                                                        _seq_length,
                                                        _hidden_size / _heads,
                                                        //aiss debug 0506
                                                        //(T(1.0) / T(sqrt(_hidden_size / _heads))),
                                                        (T(1.0 / (sqrt(_hidden_size / _heads)))),
                                                        T(0.0),
                                                        rocblas_operation_transpose,
                                                        rocblas_operation_none,
                                                        gemm_algos[3])),
      _attn_context(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                         _hidden_size / _heads,
                                                         _seq_length,
                                                         _seq_length,
                                                         T(1.0),
                                                         T(0.0),
                                                         rocblas_operation_none,
                                                         rocblas_operation_none,
                                                         gemm_algos[4]))
{
    assert(_hidden_size % _heads == 0);

    Initialize();
}

template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer()
{
}

template <typename T>
void BertTransformerLayer<T>::Initialize()
{
#ifndef __HIP_PLATFORM_HCC__
    if (std::is_same<T, __half>::value)
        rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
}

template <typename T>
void BertTransformerLayer<T>::Forward(unsigned bsz,
                                      const T* input_ptr,
                                      const T* input_mask_ptr,
                                      const T* attn_qkvw_ptr,
                                      const T* attn_qkvb_ptr,
                                      const T* attn_ow_ptr,
                                      const T* attn_ob_ptr,
                                      const T* attn_nw_ptr,
                                      const T* attn_nb_ptr,
                                      const T* inter_w_ptr,
                                      const T* inter_b_ptr,
                                      const T* output_w_ptr,
                                      const T* output_b_ptr,
                                      const T* norm_w_ptr,
                                      const T* norm_b_ptr,
                                      T* out_ptr,
                                      T* inp_norm_ptr,
                                      T* q_tf_ptr,
                                      T* k_tf_ptr,
                                      T* v_tf_ptr,
                                      T* soft_out_ptr,
                                      T* ctx_bufB_ptr,
                                      T* attn_o_inp_ptr,
                                      T* add_res_ptr,
                                      T* ff1_inp_ptr,
                                      T* gelu_inp_ptr,
                                      T* ff2_inp_ptr)
{
    rocblas_set_stream(_cublasHandle, _stream);

    if (!_stochastic_mode) hipStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1;

    if (_normalize_invertible) {
        add_res_ptr = buf_1 + 3 * small_buf_size;
        buf_2 = add_res_ptr;
    }
    if (_gelu_checkpoint) buf_2 += small_buf_size;
    if (_attn_dropout_checkpoint)
        ctx_bufB_ptr =
            (_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size)
                              : (buf_1 + 4 * small_buf_size));

    int bsz_seq = bsz * _seq_length;

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }

    if (_pre_or_postLayerNorm)
        _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
    else
        _qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);

    launch_bias_add_transform_0213<T>(
        q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);

    int bsz_heads = bsz * _heads;

    // attention scores
    _attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);

    // Softmax + Mask
    _softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);

    // attn prob dropout.
    _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);

    // attention context
    _attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);

    launch_transform4d_0213<T>(
        attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);

    if (_pre_or_postLayerNorm)
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _cublasHandle);
    else
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _cublasHandle);

    // attn output dropout.
    if (_pre_or_postLayerNorm)
        _attn_output_dropout.ForwardWithBias(
            bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream);
    else
        _attn_output_dropout.ForwardWithBias(
            bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    }

    _ff1.Forward(bsz_seq,
                 ff1_inp_ptr,
                 inter_w_ptr,
                 (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
                 _cublasHandle);

    _gelu.ForwardWithBiasAdd(bsz_seq,
                             (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
                             inter_b_ptr,
                             (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                             _stream);

    _ff2.Forward(
        bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle);

    // layer output dropout.
    if (_pre_or_postLayerNorm)
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
    else
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }
}

template <typename T>
void BertTransformerLayer<T>::Backward(unsigned bsz,
                                       const T* grad_output_ptr,
                                       const T* input_ptr,
                                       const T* output_ptr,
                                       const T* inp_norm_ptr,
                                       const T* q_tf_ptr,
                                       const T* k_tf_ptr,
                                       const T* v_tf_ptr,
                                       const T* soft_out_ptr,
                                       const T* ctx_bufB_ptr,
                                       const T* attn_o_inp_ptr,
                                       const T* add_res_ptr,
                                       const T* ff1_inp_ptr,
                                       const T* gelu_inp_ptr,
                                       const T* ff2_inp_ptr,
                                       const T* input_mask_ptr,
                                       const T* attn_qkvw_ptr,
                                       const T* attn_ow_ptr,
                                       const T* attn_nw_ptr,
                                       const T* attn_nb_ptr,
                                       const T* inter_w_ptr,
                                       const T* inter_b_ptr,
                                       const T* output_w_ptr,
                                       const T* norm_w_ptr,
                                       const T* norm_b_ptr,
                                       T* grad_input_ptr,
                                       T* grad_attn_qkvw_ptr,
                                       T* grad_attn_qkvb_ptr,
                                       T* grad_attn_ow_ptr,
                                       T* grad_attn_ob_ptr,
                                       T* grad_attn_nw_ptr,
                                       T* grad_attn_nb_ptr,
                                       T* grad_inter_w_ptr,
                                       T* grad_inter_b_ptr,
                                       T* grad_output_w_ptr,
                                       T* grad_output_b_ptr,
                                       T* grad_norm_w_ptr,
                                       T* grad_norm_b_ptr)
{
    rocblas_set_stream(_cublasHandle, _stream);

    if (!_stochastic_mode) hipStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1 + small_buf_size;
    T* buf_3 = buf_2 + small_buf_size;

    T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size)
                                   : buf_3 + small_buf_size);
    T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);

    hipStream_t streams[2] = {_stream, _stream};

    int bsz_seq = bsz * _seq_length;
    int bsz_heads = bsz * _heads;

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 inp_norm_ptr);
        else
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 norm_b_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 output_ptr);
    }

    if (_pre_or_postLayerNorm)
        _layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream);
    else
        _layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream);

    const T* layer_dropout_buf = _layer_output_dropout.HasDropout()
                                     ? buf_0
                                     : (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);

    if (_gelu_checkpoint)
        _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
    _ff2.Backward(bsz_seq,
                  layer_dropout_buf,
                  (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                  output_w_ptr,
                  grad_output_w_ptr,
                  grad_output_b_ptr,
                  _cublasHandle,
                  _stream,
                  ff2_buf);

    _gelu.Backward(
        bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);

    _ff1.Backward(bsz_seq,
                  ff2_buf,
                  ff1_inp_ptr,
                  inter_w_ptr,
                  grad_inter_w_ptr,
                  grad_inter_b_ptr,
                  _cublasHandle,
                  _stream,
                  buf_3);

    if (!_pre_or_postLayerNorm)
        launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              add_res_ptr);
        else
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              attn_nb_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              ff1_inp_ptr);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      add_res_ptr);
        else
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      attn_nb_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      ff1_inp_ptr);
    }

    _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);

    T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;

    _attn_out_linear.Backward(bsz_seq,
                              attn_output_dropout_buf,
                              attn_o_inp_ptr,
                              attn_ow_ptr,
                              grad_attn_ow_ptr,
                              grad_attn_ob_ptr,
                              _cublasHandle,
                              _stream,
                              buf_1);

    launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);

    if (_attn_prob_dropout.HasDropout()) {
        if (_attn_dropout_checkpoint)
            _attn_prob_dropout.Forward(
                bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);

        _attn_context.Backward(bsz_heads,
                               buf_2,
                               v_tf_ptr,
                               (_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr),
                               _cublasHandle,
                               buf_3,
                               ff2_buf);
    } else
        _attn_context.Backward(
            bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf);

    _attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);

    _softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);

    _attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);

    launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);

    if (_pre_or_postLayerNorm)
        _qkv_linear.Backward(bsz_seq,
                             ff2_buf,
                             inp_norm_ptr,
                             attn_qkvw_ptr,
                             grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr,
                             _cublasHandle,
                             _stream,
                             buf_2);
    else
        _qkv_linear.Backward(bsz_seq,
                             ff2_buf,
                             input_ptr,
                             attn_qkvw_ptr,
                             grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr,
                             _cublasHandle,
                             _stream,
                             buf_2);

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         input_ptr);
        else
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         norm_b_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         inp_norm_ptr);
    } else
        launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}

template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
    // Dropout will be skipped when not in training model.
    _attn_prob_dropout.SetTrainingMode(training);
    _attn_output_dropout.SetTrainingMode(training);
    _layer_output_dropout.SetTrainingMode(training);
}

template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                                     uint8_t* attn_output_dropout_mask_ptr,
                                                     uint8_t* layer_output_dropout_mask_ptr,
                                                     T* attn_layer_norm_var,
                                                     T* attn_layer_norm_mean,
                                                     T* layer_norm_var,
                                                     T* layer_norm_mean)
{
    _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
    _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
    _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);

    _attn_layer_norm.SetVar(attn_layer_norm_var);
    _attn_layer_norm.SetMean(attn_layer_norm_mean);
    _layer_norm.SetVar(layer_norm_var);
    _layer_norm.SetMean(layer_norm_mean);
}

template <typename T>
void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
{
    _seq_length = seq_len;

    _softmax.SetSeqLength(_seq_length);
    _attn_prob_dropout.SetDimension(_seq_length);
    _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
    _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
}

template <typename T>
int create_transformer_layer(unsigned layer_id,
                             unsigned batch_size,
                             unsigned hidden_dim,
                             unsigned num_heads,
                             unsigned intermediate_size,
                             float attn_dropout_ratio,
                             float hidden_dropout_ratio,
                             float layer_norm_eps,
                             int seed,
                             bool pre_or_postLayerNorm,
                             bool test_gemm,
                             bool attn_dropout_checkpoint,
                             bool normalize_invertible,
                             bool gelu_checkpoint,
                             bool stochastic_mode)
{
    Context::Instance().SetSeed(seed);
    Context::Instance().TestGemmFP16(
        test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);
    auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
                                                           batch_size,
                                                           hidden_dim,
                                                           num_heads,
                                                           intermediate_size,
                                                           init_seq_length,
                                                           attn_dropout_ratio,
                                                           hidden_dropout_ratio,
                                                           layer_norm_eps,
                                                           pre_or_postLayerNorm,
                                                           Context::Instance().GetGemmAlgos(),
                                                           attn_dropout_checkpoint,
                                                           normalize_invertible,
                                                           gelu_checkpoint,
                                                           stochastic_mode);

    s_transformer_layers[layer_id] = layer;

    std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";

    std::cout << "layer #" << layer_id << " is created with date type [" << dtype << "]."
              << std::endl;

    return 0;
}

template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
                                                  const torch::Tensor& input,
                                                  const torch::Tensor& input_mask,
                                                  const torch::Tensor& attn_qkvw,
                                                  const torch::Tensor& attn_qkvb,
                                                  const torch::Tensor& attn_ow,
                                                  const torch::Tensor& attn_ob,
                                                  const torch::Tensor& attn_nw,
                                                  const torch::Tensor& attn_nb,
                                                  const torch::Tensor& inter_w,
                                                  const torch::Tensor& inter_b,
                                                  const torch::Tensor& output_w,
                                                  const torch::Tensor& output_b,
                                                  const torch::Tensor& norm_w,
                                                  const torch::Tensor& norm_b,
                                                  bool training_mode,
                                                  bool prelayernorm,
                                                  bool attn_dropout_checkpoint,
                                                  bool normalize_invertible,
                                                  bool gelu_checkpoint)
{
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = input.size(0);

    const T* input_ptr = (const T*)input.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* output_b_ptr = (const T*)output_b.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    auto output = torch::empty_like(input);
    T* out_ptr = (T*)output.data_ptr();

    auto options = torch::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto uint8_options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .layout(torch::kStrided)
                             .device(torch::kCUDA)
                             .requires_grad(false);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (input.size(1) != seq_len) {
        seq_len = input.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
    auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
    auto attn_o_inp = torch::empty_like(input);
    auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);

    auto attn_prob_dropout_mask =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
    auto attn_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
    auto layer_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);

    auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);

    T* inp_norm_ptr = (T*)inp_norm.data_ptr();
    T* add_res_ptr = (T*)add_res.data_ptr();
    T* q_tf_ptr = (T*)qkv_tf.data_ptr();
    T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)k_tf.data_ptr();
    T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)v_tf.data_ptr();
    T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();

    torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
    torch::Tensor gelu_inp =
        (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
    auto ff1_inp = torch::empty_like(input);
    T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
    T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
    T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();

    torch::Tensor soft_out =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
    torch::Tensor ctx_bufB =
        (attn_dropout_checkpoint
             ? soft_out
             : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
    T* soft_out_ptr = (T*)soft_out.data_ptr();
    T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();

    layer->SetTrainingMode(training_mode);

    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Forward(bsz,
                   input_ptr,
                   input_mask_ptr,
                   attn_qkvw_ptr,
                   attn_qkvb_ptr,
                   attn_ow_ptr,
                   attn_ob_ptr,
                   attn_nw_ptr,
                   attn_nb_ptr,
                   inter_w_ptr,
                   inter_b_ptr,
                   output_w_ptr,
                   output_b_ptr,
                   norm_w_ptr,
                   norm_b_ptr,
                   out_ptr,
                   inp_norm_ptr,
                   q_tf_ptr,
                   k_tf_ptr,
                   v_tf_ptr,
                   soft_out_ptr,
                   ctx_bufB_ptr,
                   attn_o_inp_ptr,
                   add_res_ptr,
                   ff1_inp_ptr,
                   gelu_inp_ptr,
                   ff2_inp_ptr);

    return {output,
            inp_norm,
            qkv_tf,
            soft_out,
            ctx_bufB,
            attn_o_inp,
            add_res,
            ff1_inp,
            gelu_inp,
            ff2_inp,
            attn_prob_dropout_mask,
            attn_output_dropout_mask,
            layer_output_dropout_mask,
            attn_layer_norm_var,
            attn_layer_norm_mean,
            layer_norm_var,
            layer_norm_mean};
}

template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
                                                   const torch::Tensor& grad_output,
                                                   const torch::Tensor& output,
                                                   const torch::Tensor& inp_norm,
                                                   const torch::Tensor& qkv_tf,
                                                   const torch::Tensor& soft_out,
                                                   const torch::Tensor& ctx_bufB,
                                                   const torch::Tensor& attn_o_inp,
                                                   const torch::Tensor& add_res,
                                                   const torch::Tensor& ff1_inp,
                                                   const torch::Tensor& gelu_inp,
                                                   const torch::Tensor& ff2_inp,
                                                   const torch::Tensor& attn_prob_dropout_mask,
                                                   const torch::Tensor& attn_output_dropout_mask,
                                                   const torch::Tensor& layer_output_dropout_mask,
                                                   const torch::Tensor& attn_layer_norm_var,
                                                   const torch::Tensor& attn_layer_norm_mean,
                                                   const torch::Tensor& layer_norm_var,
                                                   const torch::Tensor& layer_norm_mean,
                                                   const torch::Tensor& input,
                                                   const torch::Tensor& input_mask,
                                                   const torch::Tensor& attn_qkvw,
                                                   const torch::Tensor& attn_qkvb,
                                                   const torch::Tensor& attn_ow,
                                                   const torch::Tensor& attn_ob,
                                                   const torch::Tensor& attn_nw,
                                                   const torch::Tensor& attn_nb,
                                                   const torch::Tensor& inter_w,
                                                   const torch::Tensor& inter_b,
                                                   const torch::Tensor& output_w,
                                                   const torch::Tensor& output_b,
                                                   const torch::Tensor& norm_w,
                                                   const torch::Tensor& norm_b)
{
    auto g_output = grad_output.contiguous();
    CHECK_INPUT(g_output);
    CHECK_INPUT(output);
    CHECK_INPUT(inp_norm);
    CHECK_INPUT(qkv_tf);
    CHECK_INPUT(add_res);
    CHECK_INPUT(soft_out);
    CHECK_INPUT(ctx_bufB);
    CHECK_INPUT(attn_o_inp);
    CHECK_INPUT(ff1_inp);
    CHECK_INPUT(gelu_inp);
    CHECK_INPUT(ff2_inp);
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = g_output.size(0);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (g_output.size(1) != seq_len) {
        seq_len = g_output.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto options = torch::TensorOptions()
                       .dtype(g_output.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);
    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto grad_input = torch::empty_like(input);
    auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
    auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
    auto grad_attn_ow = torch::empty_like(attn_ow);
    auto grad_attn_ob = torch::empty_like(attn_ob);
    auto grad_attn_nw = torch::empty_like(attn_nw);
    auto grad_attn_nb = torch::empty_like(attn_nb);
    auto grad_inter_w = torch::empty_like(inter_w);
    auto grad_inter_b = torch::empty_like(inter_b);
    auto grad_output_w = torch::empty_like(output_w);
    auto grad_output_b = torch::empty_like(output_b);
    auto grad_norm_w = torch::empty_like(norm_w);
    auto grad_norm_b = torch::empty_like(norm_b);

    // inputs.
    const T* grad_output_ptr = (const T*)g_output.data_ptr();
    const T* input_ptr = (const T*)input.data_ptr();
    const T* output_ptr = (const T*)output.data_ptr();
    const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
    const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
    const T* add_res_ptr = (const T*)add_res.data_ptr();
    const T* k_tf_ptr =
        q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)k_tf.data_ptr();
    const T* v_tf_ptr =
        k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)v_tf.data_ptr();
    const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
    const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
    const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
    const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
    const T* soft_out_ptr = (const T*)soft_out.data_ptr();
    const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    // outputs.
    T* grad_input_ptr = (T*)grad_input.data_ptr();
    T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
    T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
    T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
    T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
    T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
    T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
    T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
    T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
    T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
    T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
    T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
    T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();

    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Backward(bsz,
                    grad_output_ptr,
                    input_ptr,
                    output_ptr,
                    inp_norm_ptr,
                    q_tf_ptr,
                    k_tf_ptr,
                    v_tf_ptr,
                    soft_out_ptr,
                    ctx_bufB_ptr,
                    attn_o_inp_ptr,
                    add_res_ptr,
                    ff1_inp_ptr,
                    gelu_inp_ptr,
                    ff2_inp_ptr,
                    input_mask_ptr,
                    attn_qkvw_ptr,
                    attn_ow_ptr,
                    attn_nw_ptr,
                    attn_nb_ptr,
                    inter_w_ptr,
                    inter_b_ptr,
                    output_w_ptr,
                    norm_w_ptr,
                    norm_b_ptr,
                    grad_input_ptr,
                    grad_attn_qkvw_ptr,
                    grad_attn_qkvb_ptr,
                    grad_attn_ow_ptr,
                    grad_attn_ob_ptr,
                    grad_attn_nw_ptr,
                    grad_attn_nb_ptr,
                    grad_inter_w_ptr,
                    grad_inter_b_ptr,
                    grad_output_w_ptr,
                    grad_output_b_ptr,
                    grad_norm_w_ptr,
                    grad_norm_b_ptr);

    return {grad_input,
            grad_attn_qkvw,
            grad_attn_qkvb,
            grad_attn_ow,
            grad_attn_ob,
            grad_attn_nw,
            grad_attn_nb,
            grad_inter_w,
            grad_inter_b,
            grad_output_w,
            grad_output_b,
            grad_norm_w,
            grad_norm_b};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward_fp32",
          &ds_transformer_forward<float>,
          "DeepSpeed Transformer forward with fp32 (CUDA)");
    m.def("forward_fp16",
          &ds_transformer_forward<__half>,
          "DeepSpeed Transformer forward with fp16 (CUDA)");
    m.def("backward_fp32",
          &ds_transformer_backward<float>,
          "DeepSpeed Transformer backward with fp32 (CUDA)");
    m.def("backward_fp16",
          &ds_transformer_backward<__half>,
          "DeepSpeed Transformer backward with fp16 (CUDA)");
    m.def("create_transformer_layer_fp32",
          &create_transformer_layer<float>,
          "Create DeepSpeed Transformer Transformer Layer with fp32 (CUDA)");
    m.def("create_transformer_layer_fp16",
          &create_transformer_layer<__half>,
          "Create DeepSpeed Transformer Transformer Layer with fp16 (CUDA)");
}
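As a quick sanity check on the workspace sizing above, the host-side sketch below reproduces the element count that get_workspace_size<T>() returns for the training path without gelu checkpointing. The batch, sequence, and model dimensions are illustrative assumptions (roughly BERT-large-like), not values taken from this commit; ds_transformer_forward then allocates a tensor with that many elements of the input dtype and registers it via Context::Instance().SetWorkSpace.

#include <algorithm>
#include <cstddef>
#include <iostream>

// Mirrors get_workspace_size<T>() above: 4 blocks of bsz*seq*hidden for the
// forward pass, 2 more when training, plus max(FF1 activation, attention scores).
// Sizes below are assumed, purely for illustration.
int main()
{
    const std::size_t bsz = 8, seq = 128, hidden = 1024, inter = 4096, heads = 16;
    const std::size_t base = bsz * seq * hidden;  // 1,048,576 elements
    std::size_t elems = 4 * base;                 // forward buffers
    elems += 2 * base;                            // extra buffers used in training
    elems += (std::max)(bsz * seq * inter, 2 * bsz * heads * seq * seq);
    std::cout << elems << " workspace elements\n";  // 10,485,760 for these sizes
    return 0;
}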
csrc/transformer_bak/gelu_kernels.cu  0 → 100644
#include "custom_cuda_layers.h"
inline
__device__
float
gelu
(
const
float
x
)
{
const
float
sqrt_param
=
0.79788456080286535587989211986876
f
;
const
float
mul_param
=
0.044715
;
return
x
*
0.5
f
*
(
1.0
f
+
tanhf
(
sqrt_param
*
(
x
+
mul_param
*
x
*
x
*
x
)));
}
inline
__device__
float
d_gelu
(
const
float
x
)
{
const
float
sqrt_param
=
0.79788456080286535587989211986876
f
;
const
float
mul_param
=
0.044715
;
float
x2mul
=
x
*
x
*
mul_param
;
float
tan_h
=
tanhf
(
sqrt_param
*
(
x
+
x
*
x2mul
));
float
dg1
=
0.5
f
*
(
1.0
f
+
tan_h
);
float
dg2
=
x
*
0.5
f
*
sqrt_param
*
(
1
-
tan_h
*
tan_h
);
float
dg3
=
dg2
*
3
*
x2mul
;
return
(
dg1
+
dg2
+
dg3
);
}
/*
Fused bias add with GELU
Loads a vector of 4 elements each iteration, for stride
iterations. It was written with the intention to launch 256 thread
threadblocks, so to launch for bert-large, we would set ITERATIONS
to 4. This is currently done automatically as a heuristic, setting
the number of iterations as blocks of 1024.
For FP16, the values are loaded from memory as __half, but converted
to FP32 for the arithmetic itself, to prevent numerous overflow on
the intermediate hyperbolic tangent, since there's no intrinsic
that computes it directly.
*/
__global__
void
gelu_kernel
(
const
float
*
input
,
float
*
vals
,
int
row_stride
,
int
iterations
)
{
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
loop_stride
=
blockDim
.
x
;
const
float4
*
input_cast
=
reinterpret_cast
<
const
float4
*>
(
input
);
float4
*
vals_cast
=
reinterpret_cast
<
float4
*>
(
vals
);
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
if
(
i
*
loop_stride
+
id
<
row_stride
)
{
float4
data
=
input_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
data
.
x
=
gelu
(
data
.
x
);
data
.
y
=
gelu
(
data
.
y
);
data
.
z
=
gelu
(
data
.
z
);
data
.
w
=
gelu
(
data
.
w
);
vals_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
]
=
data
;
}
}
}
__global__
void
gelu_kernel
(
const
__half
*
input
,
__half
*
vals
,
int
row_stride
,
int
iterations
)
{
#ifdef HALF_PRECISION_AVAILABLE
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
loop_stride
=
blockDim
.
x
;
const
float2
*
input_cast
=
reinterpret_cast
<
const
float2
*>
(
input
);
float2
*
vals_cast
=
reinterpret_cast
<
float2
*>
(
vals
);
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
if
(
i
*
loop_stride
+
id
<
row_stride
)
{
float2
vals_vec
=
input_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
__half2
*
vals_half
=
reinterpret_cast
<
__half2
*>
(
&
vals_vec
);
float2
low_data
=
__half22float2
(
vals_half
[
0
]);
float2
high_data
=
__half22float2
(
vals_half
[
1
]);
low_data
.
x
=
gelu
(
low_data
.
x
);
low_data
.
y
=
gelu
(
low_data
.
y
);
high_data
.
x
=
gelu
(
high_data
.
x
);
high_data
.
y
=
gelu
(
high_data
.
y
);
vals_half
[
0
]
=
__float22half2_rn
(
low_data
);
vals_half
[
1
]
=
__float22half2_rn
(
high_data
);
vals_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
]
=
vals_vec
;
}
}
#endif
}
__global__
void
fused_bias_gelu
(
const
float
*
input
,
const
float
*
bias
,
float
*
vals
,
int
row_stride
,
int
iterations
)
{
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
loop_stride
=
blockDim
.
x
;
const
float4
*
input_cast
=
reinterpret_cast
<
const
float4
*>
(
input
);
float4
*
vals_cast
=
reinterpret_cast
<
float4
*>
(
vals
);
const
float4
*
bias_cast
=
reinterpret_cast
<
const
float4
*>
(
bias
);
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
if
(
i
*
loop_stride
+
id
<
row_stride
)
{
float4
data
=
input_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
float4
bias_data
=
bias_cast
[
i
*
loop_stride
+
id
];
data
.
x
+=
bias_data
.
x
;
data
.
y
+=
bias_data
.
y
;
data
.
z
+=
bias_data
.
z
;
data
.
w
+=
bias_data
.
w
;
data
.
x
=
gelu
(
data
.
x
);
data
.
y
=
gelu
(
data
.
y
);
data
.
z
=
gelu
(
data
.
z
);
data
.
w
=
gelu
(
data
.
w
);
vals_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
]
=
data
;
}
}
}
__global__
void
fused_bias_gelu
(
const
__half
*
input
,
const
__half
*
bias
,
__half
*
vals
,
int
row_stride
,
int
iterations
)
{
#ifdef HALF_PRECISION_AVAILABLE
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
loop_stride
=
blockDim
.
x
;
const
float2
*
input_cast
=
reinterpret_cast
<
const
float2
*>
(
input
);
float2
*
vals_cast
=
reinterpret_cast
<
float2
*>
(
vals
);
const
float2
*
bias_cast
=
reinterpret_cast
<
const
float2
*>
(
bias
);
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
if
(
i
*
loop_stride
+
id
<
row_stride
)
{
float2
vals_vec
=
input_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
float2
bias_vec
=
bias_cast
[
i
*
loop_stride
+
id
];
__half2
*
vals_half
=
reinterpret_cast
<
__half2
*>
(
&
vals_vec
);
__half2
*
bias_half
=
reinterpret_cast
<
__half2
*>
(
&
bias_vec
);
float2
low_data
=
__half22float2
(
vals_half
[
0
]);
float2
high_data
=
__half22float2
(
vals_half
[
1
]);
float2
low_bias
=
__half22float2
(
bias_half
[
0
]);
float2
high_bias
=
__half22float2
(
bias_half
[
1
]);
low_data
.
x
+=
low_bias
.
x
;
low_data
.
y
+=
low_bias
.
y
;
high_data
.
x
+=
high_bias
.
x
;
high_data
.
y
+=
high_bias
.
y
;
low_data
.
x
=
gelu
(
low_data
.
x
);
low_data
.
y
=
gelu
(
low_data
.
y
);
high_data
.
x
=
gelu
(
high_data
.
x
);
high_data
.
y
=
gelu
(
high_data
.
y
);
vals_half
[
0
]
=
__float22half2_rn
(
low_data
);
vals_half
[
1
]
=
__float22half2_rn
(
high_data
);
vals_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
]
=
vals_vec
;
}
}
#endif
}
__global__
void
d_gelu_func
(
float
*
d_output
,
const
float
*
gelu_input
,
const
float
*
bias
,
int
row_stride
,
int
iterations
)
{
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
loop_stride
=
blockDim
.
x
;
float4
*
d_output_cast
=
reinterpret_cast
<
float4
*>
(
d_output
);
const
float4
*
gelu_input_cast
=
reinterpret_cast
<
const
float4
*>
(
gelu_input
);
const
float4
*
bias_cast
=
reinterpret_cast
<
const
float4
*>
(
bias
);
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
if
(
i
*
loop_stride
+
id
<
row_stride
)
{
float4
output_data
=
d_output_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
float4
gelu_input_data
=
gelu_input_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
float4
bias_data
=
bias_cast
[
i
*
loop_stride
+
id
];
gelu_input_data
.
x
+=
bias_data
.
x
;
gelu_input_data
.
y
+=
bias_data
.
y
;
gelu_input_data
.
z
+=
bias_data
.
z
;
gelu_input_data
.
w
+=
bias_data
.
w
;
output_data
.
x
*=
d_gelu
(
gelu_input_data
.
x
);
output_data
.
y
*=
d_gelu
(
gelu_input_data
.
y
);
output_data
.
z
*=
d_gelu
(
gelu_input_data
.
z
);
output_data
.
w
*=
d_gelu
(
gelu_input_data
.
w
);
d_output_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
]
=
output_data
;
}
}
}
__global__
void
d_gelu_func
(
__half
*
d_output
,
const
__half
*
gelu_input
,
const
__half
*
bias
,
int
row_stride
,
int
iterations
)
{
#ifdef HALF_PRECISION_AVAILABLE
int
row
=
blockIdx
.
x
;
int
id
=
threadIdx
.
x
;
int
loop_stride
=
blockDim
.
x
;
float2
*
d_output_cast
=
reinterpret_cast
<
float2
*>
(
d_output
);
const
float2
*
gelu_input_cast
=
reinterpret_cast
<
const
float2
*>
(
gelu_input
);
const
float2
*
bias_cast
=
reinterpret_cast
<
const
float2
*>
(
bias
);
#pragma unroll
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
if
(
i
*
loop_stride
+
id
<
row_stride
)
{
float2
output_data
=
d_output_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
float2
gelu_input_data
=
gelu_input_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
];
float2
bias_vec
=
bias_cast
[
i
*
loop_stride
+
id
];
__half2
*
output_data_half
=
reinterpret_cast
<
__half2
*>
(
&
output_data
);
__half2
*
gelu_input_data_half
=
reinterpret_cast
<
__half2
*>
(
&
gelu_input_data
);
__half2
*
bias_half
=
reinterpret_cast
<
__half2
*>
(
&
bias_vec
);
float2
output_half_0
=
__half22float2
(
output_data_half
[
0
]);
float2
output_half_1
=
__half22float2
(
output_data_half
[
1
]);
float2
gelu_input_half_0
=
__half22float2
(
gelu_input_data_half
[
0
]);
float2
gelu_input_half_1
=
__half22float2
(
gelu_input_data_half
[
1
]);
float2
bias_half_0
=
__half22float2
(
bias_half
[
0
]);
float2
bias_half_1
=
__half22float2
(
bias_half
[
1
]);
gelu_input_half_0
.
x
+=
bias_half_0
.
x
;
gelu_input_half_0
.
y
+=
bias_half_0
.
y
;
gelu_input_half_1
.
x
+=
bias_half_1
.
x
;
gelu_input_half_1
.
y
+=
bias_half_1
.
y
;
output_half_0
.
x
*=
d_gelu
(
gelu_input_half_0
.
x
);
output_half_0
.
y
*=
d_gelu
(
gelu_input_half_0
.
y
);
output_half_1
.
x
*=
d_gelu
(
gelu_input_half_1
.
x
);
output_half_1
.
y
*=
d_gelu
(
gelu_input_half_1
.
y
);
float2
result
;
__half2
*
result_half2
=
reinterpret_cast
<
__half2
*>
(
&
result
);
result_half2
[
0
]
=
__float22half2_rn
(
output_half_0
);
result_half2
[
1
]
=
__float22half2_rn
(
output_half_1
);
d_output_cast
[
row
*
row_stride
+
i
*
loop_stride
+
id
]
=
result
;
}
}
#endif
}
template
<
typename
T
>
void
launch_bias_gelu
(
const
T
*
input
,
const
T
*
bias
,
T
*
output
,
int
intermediate_size
,
int
batch_size
,
cudaStream_t
stream
)
{
int
iterations
=
(
intermediate_size
+
1023
)
/
1024
;
int
threads
=
(
intermediate_size
-
1
)
/
(
iterations
*
4
)
+
1
;
dim3
block_dims
(
threads
);
dim3
grid_dims
(
batch_size
);
fused_bias_gelu
<<<
grid_dims
,
block_dims
,
0
,
stream
>>>
(
input
,
bias
,
output
,
intermediate_size
/
4
,
iterations
);
}
template
<
typename
T
>
void
launch_gelu
(
const
T
*
input
,
T
*
output
,
int
intermediate_size
,
int
batch_size
,
cudaStream_t
stream
)
{
int
iterations
=
(
intermediate_size
+
1023
)
/
1024
;
int
threads
=
(
intermediate_size
-
1
)
/
(
iterations
*
4
)
+
1
;
dim3
block_dims
(
threads
);
dim3
grid_dims
(
batch_size
);
gelu_kernel
<<<
grid_dims
,
block_dims
,
0
,
stream
>>>
(
input
,
output
,
intermediate_size
/
4
,
iterations
);
}
template
void
launch_bias_gelu
<
float
>(
const
float
*
,
const
float
*
,
float
*
,
int
,
int
,
cudaStream_t
);
template
void
launch_bias_gelu
<
__half
>(
const
__half
*
,
const
__half
*
,
__half
*
,
int
,
int
,
cudaStream_t
);
template
void
launch_gelu
<
float
>(
const
float
*
,
float
*
,
int
,
int
,
cudaStream_t
);
template
void
launch_gelu
<
__half
>(
const
__half
*
,
__half
*
,
int
,
int
,
cudaStream_t
);
template
<
typename
T
>
void
launch_d_gelu
(
T
*
d_output
,
const
T
*
input
,
const
T
*
bias
,
int
intermediate_size
,
int
batch_size
,
cudaStream_t
stream
)
{
int
iterations
=
(
intermediate_size
+
1023
)
/
1024
;
int
threads
=
(
intermediate_size
-
1
)
/
(
iterations
*
4
)
+
1
;
dim3
block_dims
(
threads
);
dim3
grid_dims
(
batch_size
);
d_gelu_func
<<<
grid_dims
,
block_dims
,
0
,
stream
>>>
(
d_output
,
input
,
bias
,
intermediate_size
/
4
,
iterations
);
}
template
void
launch_d_gelu
<
float
>(
float
*
,
const
float
*
,
const
float
*
,
int
,
int
,
cudaStream_t
);
template
void
launch_d_gelu
<
__half
>(
__half
*
,
const
__half
*
,
const
__half
*
,
int
,
int
,
cudaStream_t
);
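The gelu_kernels.hip file that follows is the hipify output of the file above. Apart from the added hip_runtime include and the hipStream_t stream type, the only mechanical change is the kernel-launch syntax; for example, launch_gelu's launch line appears in the two files as:

// gelu_kernels.cu (CUDA triple-chevron launch)
gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(
    input, output, intermediate_size / 4, iterations);

// gelu_kernels.hip (hipify-generated launch macro)
hipLaunchKernelGGL(( gelu_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
    input, output, intermediate_size / 4, iterations);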
csrc/transformer_bak/gelu_kernels.hip  0 → 100644
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
inline __device__ float d_gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
float x2mul = x * x * mul_param;
float tan_h = tanhf(sqrt_param * (x + x * x2mul));
float dg1 = 0.5f * (1.0f + tan_h);
float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
float dg3 = dg2 * 3 * x2mul;
return (dg1 + dg2 + dg3);
}
/*
Fused bias add with GELU
Loads a vector of 4 elements each iteration, for stride
iterations. It was written with the intention to launch 256 thread
threadblocks, so to launch for bert-large, we would set ITERATIONS
to 4. This is currently done automatically as a heuristic, setting
the number of iterations as blocks of 1024.
For FP16, the values are loaded from memory as __half, but converted
to FP32 for the arithmetic itself, to prevent numerous overflow on
the intermediate hyperbolic tangent, since there's no intrinsic
that computes it directly.
*/
__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void fused_bias_gelu(const float* input,
const float* bias,
float* vals,
int row_stride,
int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void fused_bias_gelu(const __half* input,
const __half* bias,
__half* vals,
int row_stride,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void d_gelu_func(float* d_output,
const float* gelu_input,
const float* bias,
int row_stride,
int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
float4* d_output_cast = reinterpret_cast<float4*>(d_output);
const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
gelu_input_data.x += bias_data.x;
gelu_input_data.y += bias_data.y;
gelu_input_data.z += bias_data.z;
gelu_input_data.w += bias_data.w;
output_data.x *= d_gelu(gelu_input_data.x);
output_data.y *= d_gelu(gelu_input_data.y);
output_data.z *= d_gelu(gelu_input_data.z);
output_data.w *= d_gelu(gelu_input_data.w);
d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
}
}
}
__global__ void d_gelu_func(__half* d_output,
const __half* gelu_input,
const __half* bias,
int row_stride,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
float2* d_output_cast = reinterpret_cast<float2*>(d_output);
const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
#pragma unroll
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
__half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 output_half_0 = __half22float2(output_data_half[0]);
float2 output_half_1 = __half22float2(output_data_half[1]);
float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);
float2 bias_half_0 = __half22float2(bias_half[0]);
float2 bias_half_1 = __half22float2(bias_half[1]);
gelu_input_half_0.x += bias_half_0.x;
gelu_input_half_0.y += bias_half_0.y;
gelu_input_half_1.x += bias_half_1.x;
gelu_input_half_1.y += bias_half_1.y;
output_half_0.x *= d_gelu(gelu_input_half_0.x);
output_half_0.y *= d_gelu(gelu_input_half_0.y);
output_half_1.x *= d_gelu(gelu_input_half_1.x);
output_half_1.y *= d_gelu(gelu_input_half_1.y);
float2 result;
__half2* result_half2 = reinterpret_cast<__half2*>(&result);
result_half2[0] = __float22half2_rn(output_half_0);
result_half2[1] = __float22half2_rn(output_half_1);
d_output_cast[row * row_stride + i * loop_stride + id] = result;
}
}
#endif
}
template <typename T>
void launch_bias_gelu(const T* input,
const T* bias,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, output, intermediate_size / 4, iterations);
}
template <typename T>
void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( gelu_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
input, output, intermediate_size / 4, iterations);
}
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(const __half*,
const __half*,
__half*,
int,
int,
hipStream_t);
template void launch_gelu<float>(const float*, float*, int, int, hipStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, hipStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
const T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( d_gelu_func), dim3(grid_dims), dim3(block_dims), 0, stream,
d_output, input, bias, intermediate_size / 4, iterations);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, hipStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, hipStream_t);
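For reference, the gelu() and d_gelu() device functions in both files implement the tanh approximation of GELU and its analytic derivative, where 0.79788456... is sqrt(2/pi) and 0.044715 is the usual cubic coefficient:

$$\mathrm{gelu}(x) \approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)$$

With the launch heuristic in launch_bias_gelu / launch_gelu / launch_d_gelu, an intermediate size of 4096 (the BERT-large case mentioned in the header comment) gives iterations = (4096 + 1023) / 1024 = 4 and threads = (4096 - 1) / (4 * 4) + 1 = 256, i.e. the 256-thread blocks the comment describes, with each thread handling one float4 (or, for FP16, one float2 holding two __half2 pairs) per iteration.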
csrc/transformer_bak/general_kernels.cu  0 → 100644
#include "general_kernels.h"
namespace cg = cooperative_groups;

template <typename T>
__global__ void column_sum_reduce(const T* __restrict__ inp,
                                  T* __restrict__ out,
                                  int rows,
                                  int width)
{
    __shared__ float tile[TILE_DIM][TILE_DIM + 1];

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    int y_stride = width * TILE_DIM;

    float localSum = 0;

    // Loop across matrix height
    if (idx < width) {
        int offset = threadIdx.y * width + idx;
        for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
            localSum += (float)inp[offset];
            offset += y_stride;
        }
    }

    tile[threadIdx.x][threadIdx.y] = localSum;

    __syncthreads();

    // Sum the shared buffer.
    float sum = tile[threadIdx.y][threadIdx.x];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);

    if (threadIdx.x == 0) {
        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
        if (pos < width) out[pos] = sum;
    }
}

template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp, T* out, int rows, int cols, cudaStream_t stream);

template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
                                              float* out,
                                              int rows,
                                              int cols,
                                              cudaStream_t stream)
{
    // assert(rows % TILE_DIM == 0);
    // assert(cols % TILE_DIM == 0);

    dim3 grid_dim((cols - 1) / TILE_DIM + 1);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    column_sum_reduce<float><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}

template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
                                               __half* out,
                                               int rows,
                                               int cols,
                                               cudaStream_t stream)
{
    // assert(rows % TILE_DIM == 0);
    // assert(cols % TILE_DIM == 0);

    dim3 grid_dim((cols - 1) / TILE_DIM + 1);
    dim3 block_dim(TILE_DIM, TILE_DIM);

    column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}

__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    float4* out_4 = reinterpret_cast<float4*>(out);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 val;
        float4 inp1_reg = inp1_4[j];
        float4 inp2_reg = inp2_4[j];

        val.x = inp1_reg.x + inp2_reg.x;
        val.y = inp1_reg.y + inp2_reg.y;
        val.z = inp1_reg.z + inp2_reg.z;
        val.w = inp1_reg.w + inp2_reg.w;

        out_4[j] = val;
    }
}

__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
    float2 inp1_4;
    float2 inp2_4;

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);

    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        inp1_4 = inp1_arr[j];
        inp2_4 = inp2_arr[j];

        float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
        float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

        float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
        float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

        inp1_h_f_0.x += inp2_h_f_0.x;
        inp1_h_f_0.y += inp2_h_f_0.y;
        inp1_h_f_1.x += inp2_h_f_1.x;
        inp1_h_f_1.y += inp2_h_f_1.y;

        float2 val_f;
        __half2* val_h = reinterpret_cast<__half2*>(&val_f);
        val_h[0] = __float22half2_rn(inp1_h_f_0);
        val_h[1] = __float22half2_rn(inp1_h_f_1);

        float2* out_4 = reinterpret_cast<float2*>(out);
        out_4[j] = val_f;
    }
}

template <>
void launch_fused_add2<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              int batch_size,
                              int seq_length,
                              int hidden_dim,
                              cudaStream_t& stream)
{
    int total_count = batch_size * seq_length * hidden_dim / 4;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);

    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}

template <>
void launch_fused_add2<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               int batch_size,
                               int seq_length,
                               int hidden_dim,
                               cudaStream_t& stream)
{
    int total_count = batch_size * seq_length * hidden_dim / 4;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);  //(batch_size * seq_length);
    dim3 block_dim = DS_CUDA_NUM_THREADS;        //(hidden_dim / 4);

    fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}

__global__ void fused_add3_kernel(float* out,
                                  const float* inp1,
                                  const float* inp2,
                                  const float* inp3,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);

    float4* out_4 = reinterpret_cast<float4*>(out);

    float4 val;
    float4 inp1_reg = inp1_4[row * row_stride + id];
    float4 inp2_reg = inp2_4[row * row_stride + id];
    float4 inp3_reg = inp3_4[row * row_stride + id];

    val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
    val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
    val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
    val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;

    out_4[row * row_stride + id] = val;
}

__global__ void fused_add3_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  const __half* inp3,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);

    float2 inp1_4 = inp1_arr[row * row_stride + id];
    float2 inp2_4 = inp2_arr[row * row_stride + id];
    float2 inp3_4 = inp3_arr[row * row_stride + id];

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
    __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);

    float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
    float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

    float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
    float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

    float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
    float2 inp3_h_f_1 = __half22float2(inp3_h[1]);

    inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
    inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
    inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
    inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);

    float2 val_f;
    __half2* val_h = reinterpret_cast<__half2*>(&val_f);
    val_h[0] = __float22half2_rn(inp1_h_f_0);
    val_h[1] = __float22half2_rn(inp1_h_f_1);

    float2* out_4 = reinterpret_cast<float2*>(out);
    out_4[row * row_stride + id] = val_f;
}

template <>
void launch_fused_add3<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              const float* inp3,
                              int batch_size,
                              int seq_length,
                              int hidden_size,
                              cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}

template <>
void launch_fused_add3<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               const __half* inp3,
                               int batch_size,
                               int seq_length,
                               int hidden_size,
                               cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}

__global__ void fused_add4_kernel(float* out,
                                  const float* inp1,
                                  const float* inp2,
                                  const float* inp3,
                                  const float* inp4,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
    const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);

    float4* out_4 = reinterpret_cast<float4*>(out);

    float4 val;
    float4 inp1_reg = inp1_4[row * row_stride + id];
    float4 inp2_reg = inp2_4[row * row_stride + id];
    float4 inp3_reg = inp3_4[row * row_stride + id];
    float4 inp4_reg = inp4_4[row * row_stride + id];

    val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
    val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
    val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
    val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;

    out_4[row * row_stride + id] = val;
}

__global__ void fused_add4_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  const __half* inp3,
                                  const __half* inp4,
                                  int size,
                                  int row_stride)
{
    int row = blockIdx.x;
    int id = threadIdx.x;

    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
    const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);

    float2 inp1_4 = inp1_arr[row * row_stride + id];
    float2 inp2_4 = inp2_arr[row * row_stride + id];
    float2 inp3_4 = inp3_arr[row * row_stride + id];
    float2 inp4_4 = inp4_arr[row * row_stride + id];

    __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
    __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
    __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
    __half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);

    float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
    float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

    float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
    float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

    float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
    float2 inp3_h_f_1 = __half22float2(inp3_h[1]);

    float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
    float2 inp4_h_f_1 = __half22float2(inp4_h[1]);

    inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
    inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
    inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
    inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);

    float2 val_f;
    __half2* val_h = reinterpret_cast<__half2*>(&val_f);
    val_h[0] = __float22half2_rn(inp1_h_f_0);
    val_h[1] = __float22half2_rn(inp1_h_f_1);

    float2* out_4 = reinterpret_cast<float2*>(out);
    out_4[row * row_stride + id] = val_f;
}

template <>
void launch_fused_add4<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              const float* inp3,
                              const float* inp4,
                              int batch_size,
                              int seq_length,
                              int hidden_size,
                              cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}

template <>
void launch_fused_add4<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               const __half* inp3,
                               const __half* inp4,
                               int batch_size,
                               int seq_length,
                               int hidden_size,
                               cudaStream_t& stream)
{
    dim3 grid_dim(batch_size * seq_length);
    dim3 block_dim(hidden_size / 4);

    fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
        out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
csrc/transformer_bak/general_kernels.hip
0 → 100644
View file @
7d1a83a9
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "general_kernels_hip.h"
namespace cg = cooperative_groups;
template <typename T>
__global__ void column_sum_reduce(const T* __restrict__ inp,
T* __restrict__ out,
int rows,
int width)
{
__shared__ float tile[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int y_stride = width * TILE_DIM;
float localSum = 0;
// Loop across matrix height
if (idx < width) {
int offset = threadIdx.y * width + idx;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
}
}
tile[threadIdx.x][threadIdx.y] = localSum;
__syncthreads();
// Sum the shared buffer.
float sum = tile[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
if (pos < width) out[pos] = sum;
}
}
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
T* out,
int rows,
int cols,
hipStream_t stream);
template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
float* out,
int rows,
int cols,
hipStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( column_sum_reduce<float>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
}
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
__half* out,
int rows,
int cols,
hipStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( column_sum_reduce<__half>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
}
__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
float4* out_4 = reinterpret_cast<float4*>(out);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 val;
float4 inp1_reg = inp1_4[j];
float4 inp2_reg = inp2_4[j];
val.x = inp1_reg.x + inp2_reg.x;
val.y = inp1_reg.y + inp2_reg.y;
val.z = inp1_reg.z + inp2_reg.z;
val.w = inp1_reg.w + inp2_reg.w;
out_4[j] = val;
}
}
__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
float2 inp1_4;
float2 inp2_4;
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
CUDA_1D_KERNEL_LOOP(j, N)
{
inp1_4 = inp1_arr[j];
inp2_4 = inp2_arr[j];
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
inp1_h_f_0.x += inp2_h_f_0.x;
inp1_h_f_0.y += inp2_h_f_0.y;
inp1_h_f_1.x += inp2_h_f_1.x;
inp1_h_f_1.y += inp2_h_f_1.y;
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[j] = val_f;
}
}
template <>
void launch_fused_add2<float>(float* out,
const float* inp1,
const float* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
template <>
void launch_fused_add2<__half>(__half* out,
const __half* inp1,
const __half* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
__global__ void fused_add3_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add3_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add3<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add3<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
__global__ void fused_add4_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
float4 inp4_reg = inp4_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add4_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
float2 inp4_4 = inp4_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
__half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
float2 inp4_h_f_1 = __half22float2(inp4_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add4<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add4<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
csrc/transformer_bak/inference/csrc/apply_rotary_pos_emb.cu
0 → 100644
View file @
7d1a83a9
#include "custom_cuda_layers.h"
//#include <cuda_profiler_api.h>
namespace cg = cooperative_groups;

__global__ void apply_rotary_pos_emb(float* mixed_query,
                                     float* key_layer,
                                     unsigned rotary_dim,
                                     unsigned seq_len,
                                     unsigned seq_offset,
                                     unsigned num_heads,
                                     unsigned head_size,
                                     unsigned total_count)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int id = threadIdx.x;
    int gid = id >> 5;
    int lane = id & 0x1f;

    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
    unsigned offset = head_id * head_size;

    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;

    if (head_id < total_count) {
        while (lane < rotary_dim) {
            float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
            float q = mixed_query[offset + lane];
            float k = key_layer[offset + lane];
            float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
            float q_rot = (q * rotary_sign);
            float k_rot = (k * rotary_sign);
            q_rot = g.shfl_xor(q_rot, 1);
            k_rot = g.shfl_xor(k_rot, 1);
            q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
            k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

            mixed_query[offset + lane] = q;
            key_layer[offset + lane] = k;

            lane += WARP_SIZE;
        }
    }
}

__global__ void apply_rotary_pos_emb(__half* mixed_query,
                                     __half* key_layer,
                                     unsigned rotary_dim,
                                     unsigned seq_len,
                                     unsigned seq_offset,
                                     unsigned num_heads,
                                     unsigned head_size,
                                     unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int id = threadIdx.x;
    int gid = id >> 5;
    int lane = id & 0x1f;

    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
    unsigned offset = head_id * head_size;

    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;

    if (head_id < total_count) {
        while (lane < rotary_dim) {
            float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
            float q = (float)mixed_query[offset + lane];
            float k = (float)key_layer[offset + lane];
            float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
            float q_rot = (q * rotary_sign);
            float k_rot = (k * rotary_sign);
            q_rot = g.shfl_xor(q_rot, 1);
            k_rot = g.shfl_xor(k_rot, 1);
            q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
            k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

            mixed_query[offset + lane] = (__half)q;
            key_layer[offset + lane] = (__half)k;

            lane += WARP_SIZE;
        }
    }
#endif
}

__global__ void apply_rotary_pos_emb1(float* mixed_query,
                                      float* key_layer,
                                      unsigned rotary_dim,
                                      unsigned seq_len,
                                      unsigned seq_offset,
                                      unsigned num_heads,
                                      unsigned head_size,
                                      unsigned total_count)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int id = threadIdx.x;
    int gid = id >> 5;
    int lane = id & 0x1f;

    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
    unsigned offset = head_id * head_size;

    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;

    if (head_id < total_count) {
        while (lane < rotary_dim) {
            float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
            float q = mixed_query[offset + lane];
            float k = key_layer[offset + lane];
            float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
            float q_rot = (q * rotary_sign);
            float k_rot = (k * rotary_sign);
            q_rot = g.shfl_xor(q_rot, 1);
            k_rot = g.shfl_xor(k_rot, 1);
            q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
            k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

            mixed_query[offset + lane] = q;
            key_layer[offset + lane] = k;

            lane += WARP_SIZE;
        }
    }
}

__global__ void apply_rotary_pos_emb1(__half* mixed_query,
                                      __half* key_layer,
                                      unsigned rotary_dim,
                                      unsigned seq_len,
                                      unsigned seq_offset,
                                      unsigned num_heads,
                                      unsigned head_size,
                                      unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int id = threadIdx.x;
    int gid = id >> 5;
    int lane = id & 0x1f;

    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
    unsigned offset = head_id * head_size;

    constexpr unsigned mask[32] = {
        0x1 | 0x1000,     0x2 | 0x2000,     0x4 | 0x4000,     0x8 | 0x8000,
        0x10 | 0x10000,   0x20 | 0x20000,   0x40 | 0x40000,   0x80 | 0x80000,
        0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000,
        0x1000 | 0x1,     0x2000 | 0x2,     0x4000 | 0x4,     0x8000 | 0x8,
        0x10000 | 0x10,   0x20000 | 0x20,   0x40000 | 0x40,   0x80000 | 0x80,
        0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800,
        0x1000000,        0x2000000,        0x4000000,        0x8000000,
        0x10000000,       0x20000000,       0x40000000,       0x80000000};

    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
    unsigned half_dim = rotary_dim >> 1;

    if (head_id < total_count) {
        while (lane < rotary_dim) {
            float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
            float q = (float)mixed_query[offset + lane];
            float k = (float)key_layer[offset + lane];
            float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
            float q_rot = (q * rotary_sign);
            float k_rot = (k * rotary_sign);
            auto q_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], q_rot, lane + half_dim)
                                             : __shfl_sync(mask[lane], q_rot, lane - half_dim);
            auto k_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], k_rot, lane + half_dim)
                                             : __shfl_sync(mask[lane], k_rot, lane - half_dim);
            q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
            k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);

            mixed_query[offset + lane] = (__half)q;
            key_layer[offset + lane] = (__half)k;

            lane += WARP_SIZE;
        }
    }
#endif
}

template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
                                 T* key_layer,
                                 unsigned head_size,
                                 unsigned seq_len,
                                 unsigned rotary_dim,
                                 unsigned offset,
                                 unsigned num_heads,
                                 unsigned batch,
                                 bool rotate_half,
                                 bool rotate_every_two,
                                 cudaStream_t stream)
{
    int total_count = batch * num_heads * seq_len;
    dim3 block_dims(1024);
    dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1);  // (batch_size);

    if (rotate_every_two)
        apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
            mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
    else if (rotate_half)
        apply_rotary_pos_emb1<<<grid_dims, block_dims, 0, stream>>>(
            mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}

template void launch_apply_rotary_pos_emb<float>(float*, float*, unsigned, unsigned, unsigned, unsigned, unsigned, unsigned, bool, bool, cudaStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*, __half*, unsigned, unsigned, unsigned, unsigned, unsigned, unsigned, bool, bool, cudaStream_t);
/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
constexpr unsigned mask[32] = {0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000,
0x10 | 0x10000, 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000,
0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000,
0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, 0x8000 | 0x8,
0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800,
0x1000000, 0x2000000, 0x4000000, 0x8000000,
0x10000000, 0x20000000, 0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
//float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
float inv_freq = (float)((lane % (rotary_dim >> 1)) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > 11 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], q_rot, lane + 12) : __shfl_sync(mask[lane],
q_rot, lane - 12);//g.shfl_xor(q_rot, 12); auto k_rot_tmp = lane < 12 ? __shfl_sync(mask[lane],
k_rot, lane + 12) : __shfl_sync(mask[lane], k_rot, lane - 12);//g.shfl_xor(k_rot, 12); q = q *
cosf(inv_freq) + q_rot_tmp * sinf(inv_freq); k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
cudaStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
*/
csrc/transformer_bak/inference/csrc/apply_rotary_pos_emb.hip
0 → 100644
View file @
7d1a83a9
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
//#include <cuda_profiler_api.h>
namespace cg = cooperative_groups;
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
__global__ void apply_rotary_pos_emb1(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb1(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
constexpr unsigned mask[32] = {
0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000,
0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000, 0x100 | 0x100000, 0x200 | 0x200000,
0x400 | 0x400000, 0x800 | 0x800000, 0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4,
0x8000 | 0x8, 0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800, 0x1000000,
0x2000000, 0x4000000, 0x8000000, 0x10000000, 0x20000000,
0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned half_dim = rotary_dim >> 1;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], q_rot, lane + half_dim)
: __shfl_sync(mask[lane], q_rot, lane - half_dim);
auto k_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], k_rot, lane + half_dim)
: __shfl_sync(mask[lane], k_rot, lane - half_dim);
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
bool rotate_half,
bool rotate_every_two,
hipStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
if (rotate_every_two)
hipLaunchKernelGGL(( apply_rotary_pos_emb), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
else if (rotate_half)
hipLaunchKernelGGL(( apply_rotary_pos_emb1), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
hipStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
hipStream_t);
/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
constexpr unsigned mask[32] = {0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000,
0x10 | 0x10000, 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000,
0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000,
0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, 0x8000 | 0x8,
0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800,
0x1000000, 0x2000000, 0x4000000, 0x8000000,
0x10000000, 0x20000000, 0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
//float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
float inv_freq = (float)((lane % (rotary_dim >> 1)) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > 11 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], q_rot, lane + 12) : __shfl_sync(mask[lane],
q_rot, lane - 12);//g.shfl_xor(q_rot, 12); auto k_rot_tmp = lane < 12 ? __shfl_sync(mask[lane],
k_rot, lane + 12) : __shfl_sync(mask[lane], k_rot, lane - 12);//g.shfl_xor(k_rot, 12); q = q *
cosf(inv_freq) + q_rot_tmp * sinf(inv_freq); k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
hipStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
hipLaunchKernelGGL((
apply_rotary_pos_emb), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
*/
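As a cross-check of the rotate_every_two path above, here is a plain CPU sketch of the same per-pair rotation for a single head; the function name and buffer layout are hypothetical, but the angle and sign bookkeeping mirror the float kernel.
#include <math.h>

// One head's query/key of length rotary_dim, rotated in place for position seq_id.
void rope_rotate_every_two_reference(float* q, float* k, int rotary_dim, int seq_id)
{
    for (int i = 0; i < rotary_dim; i += 2) {
        // Matches inv_freq = ((lane / 2) * 2) / rotary_dim, then 1 / 10000^inv_freq * seq_id.
        float angle = 1.0f / powf(10000.0f, (float)i / (float)rotary_dim) * (float)seq_id;
        float c = cosf(angle), s = sinf(angle);
        float q0 = q[i], q1 = q[i + 1], k0 = k[i], k1 = k[i + 1];
        // rotary_sign plus the shfl_xor(..., 1) exchange amounts to a 2-D rotation of each pair.
        q[i] = q0 * c - q1 * s;
        q[i + 1] = q1 * c + q0 * s;
        k[i] = k0 * c - k1 * s;
        k[i + 1] = k1 * c + k0 * s;
    }
}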
csrc/transformer_bak/inference/csrc/dequantize.cu
0 → 100644
View file @
7d1a83a9
#include "custom_cuda_layers.h"
#define MAX_QUANTIZE_GROUPING 1024
#define loop_unroll 1
#define loop_unroll_bits 1
__global__ void dequantize_kernel(float* output,
                                  const int8_t* input,
                                  const float* qscale,
                                  int output_size,
                                  int hidden_dim,
                                  int groups,
                                  int merge_count)
{
    unsigned merge_hidden = hidden_dim >> merge_count;
    unsigned quantization_stride = (merge_hidden * output_size) / groups;

    unsigned bid = blockIdx.x;
    unsigned tid = threadIdx.x;

    while (tid < output_size) {
        unsigned w_index = bid / merge_hidden;
        unsigned q_index = tid + bid * output_size;

        auto q = input[q_index];

        unsigned merge_hidden_total = w_index * merge_hidden;
        unsigned scale_index =
            ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
             << merge_count) +
            w_index;

        float scale_data = qscale[scale_index];

        output[q_index] = (scale_data * (float)q);
        tid += blockDim.x;
    }
}

__global__ void dequantize_kernel(__half* output,
                                  const int8_t* input,
                                  const float* qscale,
                                  unsigned output_size,
                                  unsigned hidden_dim,
                                  unsigned groups,
                                  unsigned merge_count)
{
#ifdef HALF_PRECISION_AVAILABLE
    unsigned merge_hidden = hidden_dim >> merge_count;
    unsigned quantization_stride = (merge_hidden * output_size) / groups;

    unsigned bid = blockIdx.x;
    unsigned tid = threadIdx.x;

    while (tid < output_size) {
        unsigned w_index = bid / merge_hidden;
        unsigned q_index = tid + bid * output_size;

        auto q = input[q_index];

        unsigned merge_hidden_total = w_index * merge_hidden;
        unsigned scale_index =
            ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
             << merge_count) +
            w_index;

        float scale_data = qscale[scale_index];

        output[q_index] = __float2half(scale_data * (float)q);
        tid += blockDim.x;
    }
#endif
}

template <typename T>
void launch_dequantize(T* output,
                       const int8_t* input,
                       const float* qscale,
                       unsigned output_size,
                       unsigned hidden_dim,
                       unsigned groups,
                       unsigned merge_count,
                       cudaStream_t stream)
{
    unsigned threads = 1024;
    dim3 block_dims(threads);
    dim3 grid_dims(hidden_dim);

    dequantize_kernel<<<grid_dims, block_dims, 0, stream>>>(
        output, input, qscale, output_size, hidden_dim, groups, merge_count);
}

template void launch_dequantize<float>(float*, const int8_t*, const float*, unsigned, unsigned, unsigned, unsigned, cudaStream_t);
template void launch_dequantize<__half>(__half*, const int8_t*, const float*, unsigned, unsigned, unsigned, unsigned, cudaStream_t);
csrc/transformer_bak/inference/csrc/dequantize.hip
0 → 100644
View file @
7d1a83a9
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
#define MAX_QUANTIZE_GROUPING 1024
#define loop_unroll 1
#define loop_unroll_bits 1
__global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count)
{
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = (scale_data * (float)q);
tid += blockDim.x;
}
}
__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count)
{
#ifdef HALF_PRECISION_AVAILABLE
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = __float2half(scale_data * (float)q);
tid += blockDim.x;
}
#endif
}
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count,
hipStream_t stream)
{
unsigned threads = 1024;
dim3 block_dims(threads);
dim3 grid_dims(hidden_dim);
hipLaunchKernelGGL(( dequantize_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
output, input, qscale, output_size, hidden_dim, groups, merge_count);
}
template void launch_dequantize<float>(float*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
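For orientation, a minimal host-side sketch of invoking launch_dequantize on a single (un-merged) int8 weight; all dimensions, names, and the group count below are hypothetical.
#include <hip/hip_runtime.h>
#include <stdint.h>

// Hypothetical illustration only: declaration matching the launcher instantiated above.
template <typename T>
void launch_dequantize(T* output, const int8_t* input, const float* qscale,
                       unsigned output_size, unsigned hidden_dim, unsigned groups,
                       unsigned merge_count, hipStream_t stream);

// d_out: [hidden_dim x output_size] floats; d_w: int8 weight; d_scales: one float per group.
void dequantize_example(float* d_out, const int8_t* d_w, const float* d_scales, hipStream_t stream)
{
    const unsigned output_size = 4096;  // columns handled by each block's thread loop
    const unsigned hidden_dim = 1024;   // rows; also the launch grid size (one block per row)
    const unsigned groups = 8;          // quantization groups feeding the scale_index arithmetic
    const unsigned merge_count = 0;     // no merged model-parallel partitions
    launch_dequantize<float>(d_out, d_w, d_scales, output_size, hidden_dim, groups, merge_count, stream);
}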
csrc/transformer_bak/inference/csrc/gelu.cu
0 → 100644
View file @
7d1a83a9
#include "custom_cuda_layers.h"
#define MAX_CAP 4
#define MAX_SEQ 2048
inline __device__ float gelu(const float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f;
    const float mul_param = 0.044715;
    return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}

__global__ void fused_bias_gelu(float* input,
                                const float* bias,
                                int total_count,
                                int intermediate_size)
{
    float4* input_cast = reinterpret_cast<float4*>(input);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 data = input_cast[offset];
        float4 bias_data = bias_cast[offset % intermediate_size];

        data.x += bias_data.x;
        data.y += bias_data.y;
        data.z += bias_data.z;
        data.w += bias_data.w;

        data.x = gelu(data.x);
        data.y = gelu(data.y);
        data.z = gelu(data.z);
        data.w = gelu(data.w);

        input_cast[offset] = data;
    }
}

__global__ void fused_bias_gelu(__half* input,
                                const __half* bias,
                                int total_count,
                                int intermediate_size)
{
#ifdef HALF_PRECISION_AVAILABLE

    float2* input_cast = reinterpret_cast<float2*>(input);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 vals_vec = input_cast[offset];
        float2 bias_vec = bias_cast[offset % intermediate_size];

        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

        float2 low_data = __half22float2(vals_half[0]);
        float2 high_data = __half22float2(vals_half[1]);

        float2 low_bias = __half22float2(bias_half[0]);
        float2 high_bias = __half22float2(bias_half[1]);

        low_data.x += low_bias.x;
        low_data.y += low_bias.y;
        high_data.x += high_bias.x;
        high_data.y += high_bias.y;

        low_data.x = gelu(low_data.x);
        low_data.y = gelu(low_data.y);
        high_data.x = gelu(high_data.x);
        high_data.y = gelu(high_data.y);

        vals_half[0] = __float22half2_rn(low_data);
        vals_half[1] = __float22half2_rn(high_data);

        input_cast[offset] = vals_vec;
    }
#endif
}

template <typename T>
void launch_bias_gelu(T* input,
                      const T* bias,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream)
{
    int total_count = batch_size * (intermediate_size / 4);
    int threads = 1024;  // intermediate_size / iterations / 4;
    dim3 block_dims(threads);
    dim3 grid_dims(((total_count - 1) / 1024 + 1));  // (batch_size);

    fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(
        input, bias, total_count, intermediate_size / 4);
}

template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t);

__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
{
    float4* input_cast = reinterpret_cast<float4*>(input);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 data = input_cast[offset];
        float4 bias_data = bias_cast[offset % hidden_size];

        data.x += bias_data.x;
        data.y += bias_data.y;
        data.z += bias_data.z;
        data.w += bias_data.w;

        input_cast[offset] = data;
    }
}

__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
{
#ifdef HALF_PRECISION_AVAILABLE

    float2* input_cast = reinterpret_cast<float2*>(input);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 vals_vec = input_cast[offset];
        float2 bias_vec = bias_cast[offset % hidden_size];

        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

        float2 low_data = __half22float2(vals_half[0]);
        float2 high_data = __half22float2(vals_half[1]);

        float2 low_bias = __half22float2(bias_half[0]);
        float2 high_bias = __half22float2(bias_half[1]);

        low_data.x += low_bias.x;
        low_data.y += low_bias.y;
        high_data.x += high_bias.x;
        high_data.y += high_bias.y;

        vals_half[0] = __float22half2_rn(low_data);
        vals_half[1] = __float22half2_rn(high_data);

        input_cast[offset] = vals_vec;
    }
#endif
}

template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream)
{
    int total_count = batch_size * (hidden_size / 4);
    int threads = 1024;  // hidden_size / iterations / 4;
    dim3 block_dims(threads);
    dim3 grid_dims(((total_count - 1) / threads + 1));  // (batch_size);

    fused_bias_add<<<grid_dims, block_dims, 0, stream>>>(
        input, bias, total_count, hidden_size / 4);
}

template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t);

__global__ void fused_bias_residual(float* input,
                                    float* output,
                                    float* attn,
                                    float* bias,
                                    float* attnbias,
                                    int total_count,
                                    int intermediate_size,
                                    int mp_size)
{
    float4* input_cast = reinterpret_cast<float4*>(input);
    float4* output_cast = reinterpret_cast<float4*>(output);
    float4* attn_cast = reinterpret_cast<float4*>(attn);
    float4* bias_cast = reinterpret_cast<float4*>(bias);
    float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 data = input_cast[offset];
        float4 out = output_cast[offset];
        float4 res_vec = attn_cast[offset];
        float4 bias_data = bias_cast[offset % intermediate_size];
        float4 attn_bias = attnbias_cast[offset % intermediate_size];

        data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
        data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
        data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
        data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);

        output_cast[offset] = data;
    }
}

__global__ void fused_bias_residual(__half* input,
                                    __half* output,
                                    __half* attn,
                                    __half* bias,
                                    __half* attn_bias,
                                    int total_count,
                                    int intermediate_size,
                                    int mp_size)
{
#ifdef HALF_PRECISION_AVAILABLE

    float2* input_cast = reinterpret_cast<float2*>(input);
    float2* output_cast = reinterpret_cast<float2*>(output);
    float2* attn_cast = reinterpret_cast<float2*>(attn);

    float2* bias_cast = reinterpret_cast<float2*>(bias);
    float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);

    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 vals_vec = input_cast[offset];
        float2 out_vec = output_cast[offset];
        float2 res_vec = attn_cast[offset];

        float2 bias_vec = bias_cast[offset % intermediate_size];
        float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];

        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
        __half2* out_half = reinterpret_cast<__half2*>(&out_vec);
        __half2* res_half = reinterpret_cast<__half2*>(&res_vec);
        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
        __half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);

        float2 low_data = __half22float2(vals_half[0]);
        float2 high_data = __half22float2(vals_half[1]);

        float2 low_out = __half22float2(out_half[0]);
        float2 high_out = __half22float2(out_half[1]);

        float2 low_res = __half22float2(res_half[0]);
        float2 high_res = __half22float2(res_half[1]);

        float2 low_bias = __half22float2(bias_half[0]);
        float2 high_bias = __half22float2(bias_half[1]);

        float2 attn_low_bias = __half22float2(attnbias_half[0]);
        float2 attn_high_bias = __half22float2(attnbias_half[1]);

        low_data.x = (low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
        low_data.y = (low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
        high_data.x = (high_data.x + high_res.x) * mp_size + (high_out.x + (high_bias.x + attn_high_bias.x));
        high_data.y = (high_data.y + high_res.y) * mp_size + (high_out.y + (high_bias.y + attn_high_bias.y));

        vals_half[0] = __float22half2_rn(low_data);
        vals_half[1] = __float22half2_rn(high_data);

        output_cast[offset] = vals_vec;
    }
#endif
}

template <typename T>
void launch_bias_residual(T* input,
                          T* output,
                          T* attn,
                          T* bias,
                          T* attn_bias,
                          int batch,
                          int hidden_dim,
                          int mp_size,
                          cudaStream_t stream)
{
    int total_count = batch * hidden_dim / 4;
    dim3 block_dims(1024);
    dim3 grid_dims((total_count - 1) / 1024 + 1);  // (batch_size);

    fused_bias_residual<<<grid_dims, block_dims, 0, stream>>>(
        input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}

template void launch_bias_residual<float>(float*, float*, float*, float*, float*, int, int, int, cudaStream_t);
template void launch_bias_residual<__half>(__half*, __half*, __half*, __half*, __half*, int, int, int, cudaStream_t);

__global__ void gptj_residual_add(float* input,
                                  float* output,
                                  float* attn,
                                  float* bias,
                                  float* attnbias,
                                  int total_count,
                                  int intermediate_size,
                                  float mp_size)
{
    float4* input_cast = reinterpret_cast<float4*>(input);
    float4* output_cast = reinterpret_cast<float4*>(output);
    float4* attn_cast = reinterpret_cast<float4*>(attn);
    float4* bias_cast = reinterpret_cast<float4*>(bias);
    float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);

    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 data = input_cast[offset];
        float4 out = output_cast[offset];
        float4 res_vec = attn_cast[offset];
        float4 bias_data = bias_cast[offset % intermediate_size];
        float4 attn_bias = attnbias_cast[offset % intermediate_size];

        data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x + attn_bias.x);
        data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y + attn_bias.y);
        data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z + attn_bias.z);
        data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w + attn_bias.w);

        output_cast[offset] = data;
    }
}

__global__ void gptj_residual_add(__half* input,
                                  __half* output,
                                  __half* attn,
                                  __half* bias,
                                  __half* attn_bias,
                                  int total_count,
                                  int intermediate_size,
                                  float mp_size)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)

    float2* input_cast = reinterpret_cast<float2*>(input);
    float2* output_cast = reinterpret_cast<float2*>(output);
    float2* attn_cast = reinterpret_cast<float2*>(attn);

    float2* bias_cast = reinterpret_cast<float2*>(bias);
    float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);

    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 vals_vec = input_cast[offset];
        float2 out_vec = output_cast[offset];
        float2 res_vec
=
attn_cast
[
offset
];
float2
bias_vec
=
bias_cast
[
offset
%
intermediate_size
];
float2
attn_bias_vec
=
attnbias_cast
[
offset
%
intermediate_size
];
__half2
*
vals_half
=
reinterpret_cast
<
__half2
*>
(
&
vals_vec
);
__half2
*
out_half
=
reinterpret_cast
<
__half2
*>
(
&
out_vec
);
__half2
*
res_half
=
reinterpret_cast
<
__half2
*>
(
&
res_vec
);
__half2
*
bias_half
=
reinterpret_cast
<
__half2
*>
(
&
bias_vec
);
__half2
*
attnbias_half
=
reinterpret_cast
<
__half2
*>
(
&
attn_bias_vec
);
float2
low_data
=
__half22float2
(
vals_half
[
0
]);
float2
high_data
=
__half22float2
(
vals_half
[
1
]);
float2
low_out
=
__half22float2
(
out_half
[
0
]);
float2
high_out
=
__half22float2
(
out_half
[
1
]);
float2
low_res
=
__half22float2
(
res_half
[
0
]);
float2
high_res
=
__half22float2
(
res_half
[
1
]);
float2
low_bias
=
__half22float2
(
bias_half
[
0
]);
float2
high_bias
=
__half22float2
(
bias_half
[
1
]);
float2
attn_low_bias
=
__half22float2
(
attnbias_half
[
0
]);
float2
attn_high_bias
=
__half22float2
(
attnbias_half
[
1
]);
low_data
.
x
=
low_data
.
x
*
mp_size
+
(
low_out
.
x
+
low_res
.
x
+
(
low_bias
.
x
+
attn_low_bias
.
x
));
low_data
.
y
=
low_data
.
y
*
mp_size
+
(
low_out
.
y
+
low_res
.
y
+
(
low_bias
.
y
+
attn_low_bias
.
y
));
high_data
.
x
=
high_data
.
x
*
mp_size
+
(
high_out
.
x
+
high_res
.
x
+
(
high_bias
.
x
+
attn_high_bias
.
x
));
high_data
.
y
=
high_data
.
y
*
mp_size
+
(
high_out
.
y
+
high_res
.
y
+
(
high_bias
.
y
+
attn_high_bias
.
y
));
vals_half
[
0
]
=
__float22half2_rn
(
low_data
);
vals_half
[
1
]
=
__float22half2_rn
(
high_data
);
output_cast
[
offset
]
=
vals_vec
;
}
#endif
}
template
<
typename
T
>
void
launch_gptj_residual_add
(
T
*
input
,
T
*
output
,
T
*
attn
,
T
*
bias
,
T
*
attn_bias
,
int
hidden_dim
,
int
batch
,
int
mp_size
,
cudaStream_t
stream
)
{
int
total_count
=
batch
*
hidden_dim
/
4
;
dim3
block_dims
(
1024
);
dim3
grid_dims
((
total_count
-
1
)
/
1024
+
1
);
// (batch_size);
gptj_residual_add
<<<
grid_dims
,
block_dims
,
0
,
stream
>>>
(
input
,
output
,
attn
,
bias
,
attn_bias
,
total_count
,
hidden_dim
/
4
,
1.0
/
mp_size
);
}
template
void
launch_gptj_residual_add
<
float
>(
float
*
,
float
*
,
float
*
,
float
*
,
float
*
,
int
,
int
,
int
,
cudaStream_t
);
template
void
launch_gptj_residual_add
<
__half
>(
__half
*
,
__half
*
,
__half
*
,
__half
*
,
__half
*
,
int
,
int
,
int
,
cudaStream_t
);
__global__
void
moe_res_matmul
(
float
*
residual
,
float
*
coef
,
float
*
mlp_out
,
int
seq_len
,
int
hidden_dim
)
{
unsigned
tid
=
threadIdx
.
x
;
float4
*
residual_cast
=
reinterpret_cast
<
float4
*>
(
residual
);
float4
*
coef_cast
=
reinterpret_cast
<
float4
*>
(
coef
);
float4
*
mlp_out_cast
=
reinterpret_cast
<
float4
*>
(
mlp_out
);
residual_cast
+=
blockIdx
.
x
*
hidden_dim
;
mlp_out_cast
+=
blockIdx
.
x
*
hidden_dim
;
float4
*
coef_cast2
=
coef_cast
+
hidden_dim
;
while
(
tid
<
hidden_dim
)
{
float4
res
=
residual_cast
[
tid
];
float4
mlp
=
mlp_out_cast
[
tid
];
float4
coef1
=
coef_cast
[
tid
];
float4
coef2
=
coef_cast2
[
tid
];
mlp
.
x
=
mlp
.
x
*
coef2
.
x
+
res
.
x
*
coef1
.
x
;
mlp
.
y
=
mlp
.
y
*
coef2
.
y
+
res
.
y
*
coef1
.
y
;
mlp
.
z
=
mlp
.
z
*
coef2
.
z
+
res
.
z
*
coef1
.
z
;
mlp
.
w
=
mlp
.
w
*
coef2
.
w
+
res
.
w
*
coef1
.
w
;
mlp_out_cast
[
tid
]
=
mlp
;
tid
+=
blockDim
.
x
;
}
}
__global__
void
moe_res_matmul
(
__half
*
residual
,
__half
*
coef
,
__half
*
mlp_out
,
int
seq_len
,
int
hidden_dim
)
{
unsigned
tid
=
threadIdx
.
x
;
float2
*
residual_cast
=
reinterpret_cast
<
float2
*>
(
residual
);
float2
*
mlp_out_cast
=
reinterpret_cast
<
float2
*>
(
mlp_out
);
float2
*
coef_cast
=
reinterpret_cast
<
float2
*>
(
coef
);
float2
*
coef_cast2
=
coef_cast
+
hidden_dim
;
residual_cast
+=
blockIdx
.
x
*
hidden_dim
;
mlp_out_cast
+=
blockIdx
.
x
*
hidden_dim
;
while
(
tid
<
hidden_dim
)
{
float2
res
=
residual_cast
[
tid
];
float2
coef1
=
coef_cast
[
tid
];
float2
coef2
=
coef_cast
[
tid
];
float2
data
=
mlp_out_cast
[
tid
];
__half
*
data_h
=
reinterpret_cast
<
__half
*>
(
&
data
);
__half
*
coef1_h
=
reinterpret_cast
<
__half
*>
(
&
coef1
);
__half
*
coef2_h
=
reinterpret_cast
<
__half
*>
(
&
coef2
);
__half
*
res_h
=
reinterpret_cast
<
__half
*>
(
&
res
);
data_h
[
0
]
=
res_h
[
0
]
*
coef1_h
[
0
]
+
data_h
[
0
]
*
coef2_h
[
0
];
data_h
[
1
]
=
res_h
[
1
]
*
coef1_h
[
1
]
+
data_h
[
1
]
*
coef2_h
[
1
];
data_h
[
2
]
=
res_h
[
2
]
*
coef1_h
[
2
]
+
data_h
[
2
]
*
coef2_h
[
2
];
data_h
[
3
]
=
res_h
[
3
]
*
coef1_h
[
3
]
+
data_h
[
3
]
*
coef2_h
[
3
];
mlp_out_cast
[
tid
]
=
data
;
tid
+=
blockDim
.
x
;
}
}
template
<
typename
T
>
void
launch_moe_res_matmul
(
T
*
residual
,
T
*
coef
,
T
*
mlp_out
,
int
seq_len
,
int
hidden_dim
,
cudaStream_t
stream
)
{
dim3
grid_dim
(
seq_len
);
dim3
block_dim
(
1024
);
moe_res_matmul
<<<
grid_dim
,
block_dim
,
0
,
stream
>>>
(
residual
,
coef
,
mlp_out
,
seq_len
,
hidden_dim
/
4
);
}
template
void
launch_moe_res_matmul
(
float
*
residual
,
float
*
coef
,
float
*
mlp_out
,
int
seq_len
,
int
hidden_dim
,
cudaStream_t
stream
);
template
void
launch_moe_res_matmul
(
__half
*
residual
,
__half
*
coef
,
__half
*
mlp_out
,
int
seq_len
,
int
hidden_dim
,
cudaStream_t
stream
);
csrc/transformer_bak/inference/csrc/gelu.hip
0 → 100644
View file @
7d1a83a9
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
#define MAX_CAP 4
#define MAX_SEQ 2048
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
__global__ void fused_bias_gelu(float* input,
const float* bias,
int total_count,
int intermediate_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
input_cast[offset] = data;
}
}
__global__ void fused_bias_gelu(__half* input,
const __half* bias,
int total_count,
int intermediate_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
input_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_gelu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int total_count = batch_size * (intermediate_size / 4);
int threads = 1024; // intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / 1024 + 1)); // (batch_size);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, total_count, intermediate_size / 4);
}
template void launch_bias_gelu<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, hipStream_t);
__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % hidden_size];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
input_cast[offset] = data;
}
}
__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % hidden_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
input_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, hipStream_t stream)
{
int total_count = batch_size * (hidden_size / 4);
int threads = 1024; // hidden_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / threads + 1)); // (batch_size);
hipLaunchKernelGGL(( fused_bias_add), dim3(grid_dims), dim3(block_dims), 0, stream, input, bias, total_count, hidden_size / 4);
}
template void launch_bias_add<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_add<__half>(__half*, const __half*, int, int, hipStream_t);
__global__ void fused_bias_residual(float* input,
float* output,
float* attn,
float* bias,
float* attnbias,
int total_count,
int intermediate_size,
int mp_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);
output_cast[offset] = data;
}
}
__global__ void fused_bias_residual(__half* input,
__half* output,
__half* attn,
__half* bias,
__half* attn_bias,
int total_count,
int intermediate_size,
int mp_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);
float2* bias_cast = reinterpret_cast<float2*>(bias);
float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x =
(low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
low_data.y =
(low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
high_data.x =
(high_data.x + high_res.x) * mp_size + (high_out.x + (high_bias.x + attn_high_bias.x));
high_data.y =
(high_data.y + high_res.y) * mp_size + (high_out.y + (high_bias.y + attn_high_bias.y));
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_residual(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int batch,
int hidden_dim,
int mp_size,
hipStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
hipLaunchKernelGGL(( fused_bias_residual), dim3(grid_dims), dim3(block_dims), 0, stream,
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void
launch_bias_residual<float>(float*, float*, float*, float*, float*, int, int, int, hipStream_t);
template void launch_bias_residual<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
hipStream_t);
__global__ void gptj_residual_add(float* input,
float* output,
float* attn,
float* bias,
float* attnbias,
int total_count,
int intermediate_size,
float mp_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x + attn_bias.x);
data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y + attn_bias.y);
data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z + attn_bias.z);
data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w + attn_bias.w);
output_cast[offset] = data;
}
}
__global__ void gptj_residual_add(__half* input,
__half* output,
__half* attn,
__half* bias,
__half* attn_bias,
int total_count,
int intermediate_size,
float mp_size)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);
float2* bias_cast = reinterpret_cast<float2*>(bias);
float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x =
low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x + attn_low_bias.x));
low_data.y =
low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y + attn_low_bias.y));
high_data.x =
high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x + attn_high_bias.x));
high_data.y =
high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y + attn_high_bias.y));
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_gptj_residual_add(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int hidden_dim,
int batch,
int mp_size,
hipStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
hipLaunchKernelGGL(( gptj_residual_add), dim3(grid_dims), dim3(block_dims), 0, stream,
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void launch_gptj_residual_add<float>(float*,
float*,
float*,
float*,
float*,
int,
int,
int,
hipStream_t);
template void launch_gptj_residual_add<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
hipStream_t);
__global__ void moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim)
{
unsigned tid = threadIdx.x;
float4* residual_cast = reinterpret_cast<float4*>(residual);
float4* coef_cast = reinterpret_cast<float4*>(coef);
float4* mlp_out_cast = reinterpret_cast<float4*>(mlp_out);
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;
float4* coef_cast2 = coef_cast + hidden_dim;
while (tid < hidden_dim) {
float4 res = residual_cast[tid];
float4 mlp = mlp_out_cast[tid];
float4 coef1 = coef_cast[tid];
float4 coef2 = coef_cast2[tid];
mlp.x = mlp.x * coef2.x + res.x * coef1.x;
mlp.y = mlp.y * coef2.y + res.y * coef1.y;
mlp.z = mlp.z * coef2.z + res.z * coef1.z;
mlp.w = mlp.w * coef2.w + res.w * coef1.w;
mlp_out_cast[tid] = mlp;
tid += blockDim.x;
}
}
__global__ void moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim)
{
unsigned tid = threadIdx.x;
float2* residual_cast = reinterpret_cast<float2*>(residual);
float2* mlp_out_cast = reinterpret_cast<float2*>(mlp_out);
float2* coef_cast = reinterpret_cast<float2*>(coef);
float2* coef_cast2 = coef_cast + hidden_dim;
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;
while (tid < hidden_dim) {
float2 res = residual_cast[tid];
float2 coef1 = coef_cast[tid];
float2 coef2 = coef_cast[tid];
float2 data = mlp_out_cast[tid];
__half* data_h = reinterpret_cast<__half*>(&data);
__half* coef1_h = reinterpret_cast<__half*>(&coef1);
__half* coef2_h = reinterpret_cast<__half*>(&coef2);
__half* res_h = reinterpret_cast<__half*>(&res);
data_h[0] = res_h[0] * coef1_h[0] + data_h[0] * coef2_h[0];
data_h[1] = res_h[1] * coef1_h[1] + data_h[1] * coef2_h[1];
data_h[2] = res_h[2] * coef1_h[2] + data_h[2] * coef2_h[2];
data_h[3] = res_h[3] * coef1_h[3] + data_h[3] * coef2_h[3];
mlp_out_cast[tid] = data;
tid += blockDim.x;
}
}
template <typename T>
void launch_moe_res_matmul(T* residual,
T* coef,
T* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream)
{
dim3 grid_dim(seq_len);
dim3 block_dim(1024);
hipLaunchKernelGGL(( moe_res_matmul), dim3(grid_dim), dim3(block_dim), 0, stream,
residual, coef, mlp_out, seq_len, hidden_dim / 4);
}
template void launch_moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream);
template void launch_moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream);
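For orientation only: below is a minimal host-side sketch of how the launch_bias_gelu entry point above could be driven from HIP host code. The buffer names, sizes, and main() scaffolding are illustrative assumptions and are not part of this commit; the only requirement taken from the kernel itself is that intermediate_size be a multiple of 4, since rows are read as float4/float2 vectors.
// Hypothetical driver (not in this commit); assumes gelu.hip is compiled and linked in.
#include <hip/hip_runtime.h>
#include <vector>
template <typename T>
void launch_bias_gelu(T*, const T*, int, int, hipStream_t);  // provided by gelu.hip
int main()
{
    const int batch = 8, seq = 128, intermediate_size = 4096;  // must be a multiple of 4
    const int rows = batch * seq;
    std::vector<float> h_input(size_t(rows) * intermediate_size, 1.0f);
    std::vector<float> h_bias(intermediate_size, 0.5f);
    float *d_input = nullptr, *d_bias = nullptr;
    hipMalloc((void**)&d_input, h_input.size() * sizeof(float));
    hipMalloc((void**)&d_bias, h_bias.size() * sizeof(float));
    hipMemcpy(d_input, h_input.data(), h_input.size() * sizeof(float), hipMemcpyHostToDevice);
    hipMemcpy(d_bias, h_bias.data(), h_bias.size() * sizeof(float), hipMemcpyHostToDevice);
    hipStream_t stream;
    hipStreamCreate(&stream);
    // Adds the bias and applies the tanh-approximated GELU in place on d_input.
    launch_bias_gelu<float>(d_input, d_bias, intermediate_size, rows, stream);
    hipStreamSynchronize(stream);
    hipStreamDestroy(stream);
    hipFree(d_input);
    hipFree(d_bias);
    return 0;
}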
csrc/transformer_bak/inference/csrc/normalize.cu
0 → 100644
View file @
7d1a83a9
#include <limits>
#include "custom_cuda_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define NORM_REG (MAX_REGISTERS)
namespace cg = cooperative_groups;
__global__ void fused_bias_residual_layer_norm(float* output,
                                               const float* vals,
                                               const float* gamma,
                                               const float* beta,
                                               float epsilon,
                                               int row_stride)
{
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;
    int warp_num = iteration_stride >> 5;
    float inp_reg[NORM_REG];
    int k = 0;
    float sum = 0;
    int input_id = id;
    while (input_id < row_stride) {
        inp_reg[k] = vals[input_id + row * row_stride];
        sum += inp_reg[k++];
        input_id += iteration_stride;
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    __shared__ float shr[MAX_WARP_NUM];
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);
    float mean = sum / (row_stride);
    sum = 0.f;
    for (int f = 0; f < k; f++) {
        inp_reg[f] -= mean;
        sum += inp_reg[f] * inp_reg[f];
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);
    sum /= (row_stride);
    sum += epsilon;
    sum = __frsqrt_rn(sum);
    for (int f = 0; f < k; f++) {
        int out_id = f * iteration_stride + id;
        inp_reg[f] = inp_reg[f] * sum;
        inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
        output[out_id + row * row_stride] = inp_reg[f];
    }
}
__global__ void fused_bias_residual_layer_norm(__half* output,
                                               const __half* vals,
                                               const __half* gamma,
                                               const __half* beta,
                                               float epsilon,
                                               int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
    int iteration_stride = blockDim.x;
    int iterations = row_stride / iteration_stride;
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;
    int warp_num = iteration_stride >> 5;
    __half2 inp_reg[NORM_REG];
    const __half2* vals_cast = reinterpret_cast<const __half2*>(vals);
    __half2* out_cast = reinterpret_cast<__half2*>(output);
    int k = 0;
    int input_id = id;
    while (input_id < row_stride) {
        inp_reg[k++] = vals_cast[input_id + row * row_stride];
        input_id += iteration_stride;
    }
    float sum = 0;
    for (int f = k - 1; f >= 0; f--) {
        float2 inp_f = __half22float2(inp_reg[f]);
        sum += inp_f.x + inp_f.y;
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    __shared__ float shr[MAX_WARP_NUM];
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);
    float mean = sum / (row_stride << 1);
    sum = 0.f;
    for (int f = 0; f < k; f++) {
        float2 inp_f = __half22float2(inp_reg[f]);
        inp_f.x -= mean;
        inp_f.y -= mean;
        inp_reg[f] = __float22half2_rn(inp_f);
        sum += inp_f.x * inp_f.x;
        sum += inp_f.y * inp_f.y;
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);
    sum /= (row_stride << 1);
    sum += epsilon;
    sum = __frsqrt_rn(sum);
    __half2 variance_h = __float2half2_rn(sum);
    const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
    const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
    for (int f = 0; f < k; f++) {
        int out_id = f * iteration_stride + id;
        inp_reg[f] = inp_reg[f] * variance_h;
        inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
        out_cast[out_id + row * row_stride] = inp_reg[f];
    }
#endif
}
template <typename T>
void launch_layer_norm(T* out,
                       T* vals,
                       const T* gamma,
                       const T* beta,
                       float epsilon,
                       int batch_size,
                       int hidden_dim,
                       cudaStream_t stream);
template <>
void launch_layer_norm<float>(float* out,
                              float* vals,
                              const float* gamma,
                              const float* beta,
                              float epsilon,
                              int batch_size,
                              int hidden_dim,
                              cudaStream_t stream)
{
    constexpr int threads = 1024;
    dim3 grid_dim(batch_size);
    dim3 block_dim(threads);
    fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        out, vals, gamma, beta, epsilon, hidden_dim);
}
template <>
void launch_layer_norm<__half>(__half* out,
                               __half* vals,
                               const __half* gamma,
                               const __half* beta,
                               float epsilon,
                               int batch_size,
                               int hidden_dim,
                               cudaStream_t stream)
{
    constexpr int threads = 1024;
    dim3 grid_dim(batch_size);
    dim3 block_dim(threads);
    fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        out, vals, gamma, beta, epsilon, hidden_dim / 2);
}
__global__ void fused_residual_layer_norm(float* norm,
                                          float* res_add,
                                          float* vals,
                                          float* residual,
                                          const float* bias,
                                          const float* gamma,
                                          const float* beta,
                                          float epsilon,
                                          int row_stride,
                                          bool preLN,
                                          bool mlp_after_attn)
{
    int iteration_stride = blockDim.x;
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;
    int warp_num = iteration_stride >> 5;
    float inp_reg[NORM_REG];
    int k = 0;
    int input_id = id;
    float sum = 0;
    while (input_id < row_stride) {
        inp_reg[k] = vals[input_id + row * row_stride];
        float res_f = (residual[input_id + row * row_stride]);
        float bias_f = (bias[input_id]);
        if (mlp_after_attn) inp_reg[k] += res_f + bias_f;
        // if (preLN) res_add[input_id + row * row_stride] = inp_reg[k];
        sum += inp_reg[k++];
        input_id += iteration_stride;
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    __shared__ float shr[MAX_WARP_NUM];
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);
    float mean = sum / (row_stride);
    sum = 0.f;
    for (int f = 0; f < k; f++) {
        inp_reg[f] -= mean;
        sum += inp_reg[f] * inp_reg[f];
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);
    sum /= (row_stride);
    sum += epsilon;
    sum = __frsqrt_rn(sum);
    for (int f = 0; f < k; f++) {
        int out_id = f * iteration_stride + id;
        inp_reg[f] = inp_reg[f] * sum;
        inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
        norm[out_id + row * row_stride] = inp_reg[f];
    }
}
__global__ void fused_residual_layer_norm(__half* norm,
                                          __half* res_add,
                                          __half* vals,
                                          __half* residual,
                                          const __half* bias,
                                          const __half* gamma,
                                          const __half* beta,
                                          float epsilon,
                                          int row_stride,
                                          bool preLN,
                                          bool mlp_after_attn)
{
#ifdef HALF_PRECISION_AVAILABLE
    int iteration_stride = blockDim.x;
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;
    int warp_num = iteration_stride >> 5;
    __half2 inp_reg[NORM_REG];
    __half2* vals_cast = reinterpret_cast<__half2*>(vals);
    __half2* norm_cast = reinterpret_cast<__half2*>(norm);
    __half2* res_add_cast = reinterpret_cast<__half2*>(res_add);
    __half2* residual_cast = reinterpret_cast<__half2*>(residual);
    const __half2* bias_cast = reinterpret_cast<const __half2*>(bias);
    int k = 0;
    int input_id = id;
    float sum = 0;
    while (input_id < row_stride) {
        inp_reg[k] = vals_cast[input_id + row * row_stride];
        float2 inp_f = __half22float2(inp_reg[k]);
        float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]);
        float2 bias_f = __half22float2(bias_cast[input_id]);
        if (mlp_after_attn) {
            inp_f.x += res_f.x + bias_f.x;
            inp_f.y += res_f.y + bias_f.y;
        }
        inp_reg[k] = __float22half2_rn(inp_f);
        // if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f);
        // //inp_reg[k];
        sum += inp_f.x + inp_f.y;
        input_id += iteration_stride;
        k++;
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    __shared__ float shr[MAX_WARP_NUM];
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);
    float mean = sum / (row_stride << 1);
    sum = 0.f;
    for (int f = 0; f < k; f++) {
        float2 inp_f = __half22float2(inp_reg[f]);
        inp_f.x -= mean;
        inp_f.y -= mean;
        inp_reg[f] = __float22half2_rn(inp_f);
        sum += inp_f.x * inp_f.x;
        sum += inp_f.y * inp_f.y;
    }
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);
    sum /= (row_stride << 1);
    sum += epsilon;
    sum = __frsqrt_rn(sum);
    __half2 variance_h = __float2half2_rn(sum);
    const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
    const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
    for (int f = 0; f < k; f++) {
        int out_id = f * iteration_stride + id;
        inp_reg[f] = inp_reg[f] * variance_h;
        inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
        norm_cast[out_id + row * row_stride] = inp_reg[f];
    }
#endif
}
template <typename T>
void launch_residual_layer_norm(T* norm,
                                T* res_add,
                                T* vals,
                                T* residual,
                                const T* bias,
                                const T* gamma,
                                const T* beta,
                                float epsilon,
                                int batch_size,
                                int hidden_dim,
                                bool preLN,
                                bool mlp_after_attn,
                                cudaStream_t stream);
template <>
void launch_residual_layer_norm<float>(float* norm,
                                       float* res_add,
                                       float* vals,
                                       float* residual,
                                       const float* bias,
                                       const float* gamma,
                                       const float* beta,
                                       float epsilon,
                                       int batch_size,
                                       int hidden_dim,
                                       bool preLN,
                                       bool mlp_after_attn,
                                       cudaStream_t stream)
{
    constexpr int threads = 1024;
    dim3 grid_dim(batch_size);
    dim3 block_dim(threads);
    fused_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        norm, res_add, vals, residual, bias, gamma, beta, epsilon, hidden_dim, preLN, mlp_after_attn);
}
template <>
void launch_residual_layer_norm<__half>(__half* norm,
                                        __half* res_add,
                                        __half* vals,
                                        __half* residual,
                                        const __half* bias,
                                        const __half* gamma,
                                        const __half* beta,
                                        float epsilon,
                                        int batch_size,
                                        int hidden_dim,
                                        bool preLN,
                                        bool mlp_after_attn,
                                        cudaStream_t stream)
{
    constexpr int threads = 1024;
    dim3 grid_dim(batch_size);
    dim3 block_dim(threads);
    fused_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
        norm, res_add, vals, residual, bias, gamma, beta, epsilon, hidden_dim / 2, preLN, mlp_after_attn);
}
csrc/transformer_bak/inference/csrc/normalize.hip
0 → 100644
View file @
7d1a83a9
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <limits>
#include "custom_hip_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define NORM_REG (MAX_REGISTERS)
namespace cg = cooperative_groups;
__global__ void fused_bias_residual_layer_norm(float* output,
const float* vals,
const float* gamma,
const float* beta,
float epsilon,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
float sum = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
output[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_bias_residual_layer_norm(__half* output,
const __half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
const __half2* vals_cast = reinterpret_cast<const __half2*>(vals);
__half2* out_cast = reinterpret_cast<__half2*>(output);
int k = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k++] = vals_cast[input_id + row * row_stride];
input_id += iteration_stride;
}
float sum = 0;
for (int f = k - 1; f >= 0; f--) {
float2 inp_f = __half22float2(inp_reg[f]);
sum += inp_f.x + inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
out_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_layer_norm(T* out,
T* vals,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream);
template <>
void launch_layer_norm<float>(float* out,
float* vals,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
out, vals, gamma, beta, epsilon, hidden_dim);
}
template <>
void launch_layer_norm<__half>(__half* out,
__half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
out, vals, gamma, beta, epsilon, hidden_dim / 2);
}
__global__ void fused_residual_layer_norm(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
float res_f = (residual[input_id + row * row_stride]);
float bias_f = (bias[input_id]);
if (mlp_after_attn) inp_reg[k] += res_f + bias_f;
// if (preLN) res_add[input_id + row * row_stride] = inp_reg[k];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
norm[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_residual_layer_norm(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
__half2* norm_cast = reinterpret_cast<__half2*>(norm);
__half2* res_add_cast = reinterpret_cast<__half2*>(res_add);
__half2* residual_cast = reinterpret_cast<__half2*>(residual);
const __half2* bias_cast = reinterpret_cast<const __half2*>(bias);
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals_cast[input_id + row * row_stride];
float2 inp_f = __half22float2(inp_reg[k]);
float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]);
float2 bias_f = __half22float2(bias_cast[input_id]);
if (mlp_after_attn) {
inp_f.x += res_f.x + bias_f.x;
inp_f.y += res_f.y + bias_f.y;
}
inp_reg[k] = __float22half2_rn(inp_f);
// if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f);
// //inp_reg[k];
sum += inp_f.x + inp_f.y;
input_id += iteration_stride;
k++;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
norm_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_residual_layer_norm(T* norm,
T* res_add,
T* vals,
T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream);
template <>
void launch_residual_layer_norm<float>(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream, norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim,
preLN,
mlp_after_attn);
}
template <>
void launch_residual_layer_norm<__half>(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream, norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim / 2,
preLN,
mlp_after_attn);
}
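As a plain-C++ cross-check of the per-row arithmetic performed by fused_bias_residual_layer_norm and fused_residual_layer_norm above (mean, variance, epsilon added to the variance before the reciprocal square root, then the gamma/beta affine transform), the following single-threaded reference is a sketch added for this write-up only; it is not code from the commit.
#include <cmath>
#include <cstddef>
#include <vector>
// Reference LayerNorm for one row: out[i] = (x[i] - mean) * rsqrt(var + eps) * gamma[i] + beta[i]
std::vector<float> layer_norm_row(const std::vector<float>& x,
                                  const std::vector<float>& gamma,
                                  const std::vector<float>& beta,
                                  float epsilon)
{
    const std::size_t n = x.size();
    float mean = 0.f;
    for (float v : x) mean += v;
    mean /= n;
    float var = 0.f;
    for (float v : x) var += (v - mean) * (v - mean);
    var /= n;
    // Mirrors the kernel's "sum += epsilon; sum = __frsqrt_rn(sum);" step.
    const float inv_std = 1.0f / std::sqrt(var + epsilon);
    std::vector<float> out(n);
    for (std::size_t i = 0; i < n; ++i) out[i] = (x[i] - mean) * inv_std * gamma[i] + beta[i];
    return out;
}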
csrc/transformer_bak/inference/csrc/pt_binding.cpp
0 → 100644
View file @
7d1a83a9
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <vector>
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});
#define MAX_OUT_TOKES 10
template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores,
                      at::Tensor& attn_mask,
                      bool triangular,
                      bool recompute,
                      bool local_attention,
                      int window_size,
                      bool async_op)
{
    auto attn_scores_c = attn_scores.contiguous();
    int bsz = attn_scores_c.size(0);
    int seq_len = attn_scores_c.size(1);
    int len = attn_scores_c.sizes().size();
    if (len > 3) seq_len = attn_scores_c.size(2);
    int soft_len = attn_scores_c.size(2);
    if (len > 3) soft_len = attn_scores_c.size(3);
    int heads = 1;
    if (len > 3) heads = attn_scores_c.size(1);
    launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(),
                           (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
                           triangular,
                           recompute,
                           local_attention,
                           window_size,
                           bsz,
                           heads,
                           seq_len,
                           soft_len,
                           1.0,
                           Context::Instance().GetCurrentStream(async_op));
    return attn_scores_c;
}
template <typename T>
void allocate_workspace(size_t hidden_dim,
                        size_t max_seq_len,
                        size_t batch_size,
                        size_t head_size = 128)
{
    size_t _workSpaceSize = (hidden_dim * batch_size * max_seq_len);
    Context::Instance().GenWorkSpace(_workSpaceSize * sizeof(T));
}
template <typename T>
at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
{
    auto options = at::TensorOptions()
                       .dtype(Q.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    T* workspace = (T*)Context::Instance().GetWorkSpace();
    float alpha = 1;
    float gemm_beta = 0.0;
    if (!workspace) {
        allocate_workspace<T>(W.size(1), MAX_OUT_TOKES, Q.size(0));
        workspace = (T*)Context::Instance().GetWorkSpace();
    }
    auto O = at::from_blob(workspace, {Q.size(1), Q.size(2), W.size(1)}, options);
    unsigned m = W.size(1);
    unsigned n = Q.size(1) * Q.size(2);
    unsigned k = Q.size(0);
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N,
                   CUBLAS_OP_T,
                   m,
                   n,
                   k,
                   &alpha,
                   &gemm_beta,
                   (T*)W.data_ptr(),
                   (T*)Q.data_ptr(),
                   (T*)O.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    return O;
}
template <typename T>
void attention_unfused(at::Tensor& prev_key_cont,
                       at::Tensor& query_cont,
                       at::Tensor& attn_mask,
                       at::Tensor& prev_value_cont,
                       at::Tensor& output,
                       int& bsz,
                       int& seq_len,
                       int& soft_len,
                       int& heads,
                       float& norm_factor,
                       bool triangular,
                       bool recompute,
                       bool local_attention,
                       int window_size)
{
    auto options = at::TensorOptions()
                       .dtype(query_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    float alpha = norm_factor;
    float gemm_beta = 0.0;
    auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options);
    int k = prev_value_cont.size(2) / heads;
    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
    cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
                                soft_len,
                                seq_len,
                                k,
                                &alpha,
                                &gemm_beta,
                                (T*)prev_key_cont.data_ptr(),
                                (T*)query_cont.data_ptr(),
                                (T*)attn_score.data_ptr(),
                                CUBLAS_OP_N,
                                CUBLAS_OP_N,
                                soft_len * k,
                                seq_len * k,
                                seq_len * soft_len,
                                bsz * heads,
                                CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    attn_score = ds_softmax<T>(
        attn_score, attn_mask, triangular, recompute, local_attention, window_size, false);
    alpha = 1.0;
    cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
                                k,
                                seq_len,
                                soft_len,
                                &alpha,
                                &gemm_beta,
                                (T*)prev_value_cont.data_ptr(),
                                (T*)attn_score.data_ptr(),
                                (T*)output.data_ptr(),
                                CUBLAS_OP_N,
                                CUBLAS_OP_N,
                                soft_len * k,
                                seq_len * soft_len,
                                seq_len * k,
                                bsz * heads,
                                CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
std::vector<at::Tensor> ds_softmax_context(at::Tensor& query,
                                           at::Tensor& prev_key,
                                           at::Tensor& new_key,
                                           at::Tensor& attn_mask,
                                           at::Tensor& prev_value,
                                           at::Tensor& new_value,
                                           int heads,
                                           float norm_factor,
                                           bool merging,
                                           bool triangular,
                                           bool local_attention,
                                           int window_size,
                                           bool no_masking)
{
    auto query_cont = query.contiguous();
    auto prev_key_cont = prev_key.contiguous();
    auto prev_value_cont = prev_value.contiguous();
    int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0);
    // Attn_Score [ batch Head Sequence-length Softmax-length]
    int bsz = query_cont.size(0);
    int seq_len = query_cont.size(1);
    int soft_len = prev_value.size(1);
    auto options = at::TensorOptions()
                       .dtype(query_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output =
        at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options);
    attention_unfused<T>(prev_key_cont,
                         query_cont,
                         attn_mask,  //(no_masking ? nullptr : (T*)attn_mask.data_ptr()),
                         prev_value_cont,
                         output,
                         bsz,
                         seq_len,
                         soft_len,
                         heads,
                         norm_factor,
                         (triangular && (new_size == 0)),
                         (new_size == 0),
                         local_attention,
                         window_size);
    return {output, prev_key, prev_value};
}
template <typename T>
at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
{
    auto input_cont = input.contiguous();
    int bsz = input_cont.size(0) * input_cont.size(1);
    int intermediate_size = input_cont.size(2);
    launch_bias_gelu((T*)input_cont.data_ptr(),
                     (T*)bias.data_ptr(),
                     intermediate_size,
                     bsz,
                     Context::Instance().GetCurrentStream());
    return input_cont;
}
template <typename T>
at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias)
{
    auto input_cont = input.contiguous();
    auto residual_cont = residual.contiguous();
    int bsz = input_cont.size(0) * input_cont.size(1);
    // launch_bias_residual((T*)input_cont.data_ptr(),
    //                      (T*)residual_cont.data_ptr(),
    //                      (T*)bias.data_ptr(),
    //                      bsz,
    //                      input_cont.size(2),
    //                      (bias.size(0) > 1),
    //                      Context::Instance().GetCurrentStream());
    return input_cont;
}
template <typename T>
at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& betta, float epsilon)
{
    int bsz = input_cont.size(0) * input_cont.size(1);
    auto inp_norm = at::empty_like(input_cont);
    launch_layer_norm((T*)inp_norm.data_ptr(),
                      (T*)input_cont.data_ptr(),
                      (T*)gamma.data_ptr(),
                      (T*)betta.data_ptr(),
                      epsilon,
                      bsz,
                      input_cont.size(2),
                      Context::Instance().GetCurrentStream());
    return inp_norm;
}
template <typename T>
at::Tensor qkv_unfused_cublas(at::Tensor& output,
                              at::Tensor& input,
                              at::Tensor& weight,
                              at::Tensor& bias,
                              at::Tensor& gamma,
                              at::Tensor& beta,
                              const float epsilon,
                              bool add_bias)
{
    auto inp_norm = ds_layernorm<T>(input, gamma, beta, epsilon);
    // cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());
    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    int bsz = input.size(0) * input.size(1);
    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N,
                   CUBLAS_OP_N,
                   weight.size(1),
                   bsz,
                   input.size(2),
                   &alpha,
                   &gemm_beta,
                   (T*)weight.data_ptr(),
                   (T*)inp_norm.data_ptr(),
                   (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    if (add_bias)
        launch_bias_add((T*)output.data_ptr(),
                        (T*)bias.data_ptr(),
                        weight.size(1),
                        bsz,
                        Context::Instance().GetCurrentStream());
    return inp_norm;
}
template <typename T>
std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
                                    at::Tensor& weight,
                                    at::Tensor& bias,
                                    at::Tensor& gamma,
                                    at::Tensor& beta,
                                    const float epsilon,
                                    bool add_bias)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    auto inp_norm =
        qkv_unfused_cublas<T>(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias);
    return {output, inp_norm};
}
template <typename T>
void quantized_gemm(at::Tensor& output,
                    at::Tensor& input,
                    at::Tensor& weight,
                    at::Tensor& qscale,
                    int groups,
                    int merge_count)
{
    int bsz = input.size(0) * input.size(1);
    auto options = at::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto weight16 = at::empty({weight.size(0), weight.size(1)}, options);
    launch_dequantize((T*)weight16.data_ptr(),
                      (int8_t*)weight.data_ptr(),
                      (float*)qscale.data_ptr(),
                      weight.size(1),
                      weight.size(0),
                      groups,
                      merge_count,
                      Context::Instance().GetCurrentStream());
    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N,
                   CUBLAS_OP_N,
                   weight.size(1),
                   bsz,
                   input.size(2),
                   &alpha,
                   &gemm_beta,
                   (T*)weight16.data_ptr(),
                   (T*)input.data_ptr(),
                   (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
at::Tensor ds_qkv_gemm_int8(at::Tensor& input,
                            at::Tensor& weight,
                            at::Tensor& bias,
                            at::Tensor& gamma,
                            at::Tensor& beta,
                            const float epsilon,
                            at::Tensor& q_scale,
                            int groups,
                            bool add_bias)
{
    int bsz = input.size(0) * input.size(1);
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    auto inp_norm = ds_layernorm<T>(input_cont, gamma, beta, epsilon);
    quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
    if (add_bias)
        launch_bias_add((T*)output.data_ptr(),
                        (T*)bias.data_ptr(),
                        weight.size(1),
                        bsz,
                        Context::Instance().GetCurrentStream());
    return output;
}
template <typename T>
at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
                   CUBLAS_OP_N,
                   CUBLAS_OP_N,
                   weight.size(1),
                   bsz,
                   input_cont.size(2),
                   &alpha,
                   &gemm_beta,
                   (T*)weight.data_ptr(),
                   (T*)input_cont.data_ptr(),
                   (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    launch_bias_add((T*)output.data_ptr(),
                    (T*)bias.data_ptr(),
                    weight.size(1),
                    bsz,
                    Context::Instance().GetCurrentStream());
    return output;
}
template <typename T>
at::Tensor ds_linear_layer_int8(at::Tensor& input,
                                at::Tensor& weight,
                                at::Tensor& bias,
                                at::Tensor& q_scale,
                                int groups)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    int bsz = input_cont.size(0) * input_cont.size(1);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
    launch_bias_add((T*)output.data_ptr(),
                    (T*)bias.data_ptr(),
                    weight.size(1),
                    bsz,
                    Context::Instance().GetCurrentStream());
    return output;
}
template <typename T>
at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad
(
false
);
auto
output
=
at
::
empty
({
input_cont
.
size
(
0
),
input_cont
.
size
(
1
),
weight
.
size
(
1
)},
options
);
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
float
alpha
=
(
T
)
1.0
;
float
gemm_beta
=
(
T
)
0.0
;
cublasSetStream
(
Context
::
Instance
().
GetCublasHandle
(),
Context
::
Instance
().
GetCurrentStream
(
async_op
));
cublas_gemm_ex
(
Context
::
Instance
().
GetCublasHandle
(),
CUBLAS_OP_N
,
CUBLAS_OP_N
,
weight
.
size
(
1
),
bsz
,
input_cont
.
size
(
2
),
&
alpha
,
&
gemm_beta
,
(
T
*
)
weight
.
data_ptr
(),
(
T
*
)
input_cont
.
data_ptr
(),
(
T
*
)
output
.
data_ptr
(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
return
output
;
}
template
<
typename
T
>
at
::
Tensor
ds_vector_matmul_int8
(
at
::
Tensor
&
input
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
q_scale
,
int
groups
,
int
merge_count
)
{
auto
input_cont
=
input
.
contiguous
();
auto
options
=
at
::
TensorOptions
()
.
dtype
(
input_cont
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
output
=
at
::
empty
({
input_cont
.
size
(
0
),
input_cont
.
size
(
1
),
weight
.
size
(
1
)},
options
);
quantized_gemm
<
T
>
(
output
,
input_cont
,
weight
,
q_scale
,
groups
,
merge_count
);
return
output
;
}
template
<
typename
T
>
void
mlp_unfused_cublas
(
at
::
Tensor
&
output
,
at
::
Tensor
&
input
,
at
::
Tensor
&
residual
,
at
::
Tensor
&
input_bias
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
bias
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
beta
,
const
float
epsilon
,
bool
preLayerNorm
,
bool
mlp_after_attn
)
{
int
bsz
=
input
.
size
(
0
)
*
input
.
size
(
1
);
auto
inp_norm
=
at
::
empty_like
(
input
);
launch_residual_layer_norm
((
T
*
)
inp_norm
.
data_ptr
(),
(
T
*
)
nullptr
,
(
T
*
)
input
.
data_ptr
(),
(
T
*
)
residual
.
data_ptr
(),
(
T
*
)
input_bias
.
data_ptr
(),
(
T
*
)
gamma
.
data_ptr
(),
(
T
*
)
beta
.
data_ptr
(),
epsilon
,
bsz
,
input
.
size
(
2
),
preLayerNorm
,
mlp_after_attn
,
Context
::
Instance
().
GetCurrentStream
());
float
alpha
=
(
T
)
1.0
;
float
gemm_beta
=
(
T
)
0.0
;
cublasSetStream
(
Context
::
Instance
().
GetCublasHandle
(),
Context
::
Instance
().
GetCurrentStream
());
cublas_gemm_ex
(
Context
::
Instance
().
GetCublasHandle
(),
CUBLAS_OP_N
,
CUBLAS_OP_N
,
weight
.
size
(
1
),
bsz
,
input
.
size
(
2
),
&
alpha
,
&
gemm_beta
,
(
T
*
)
weight
.
data_ptr
(),
(
T
*
)
inp_norm
.
data_ptr
(),
(
T
*
)
output
.
data_ptr
(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
launch_bias_gelu
((
T
*
)
output
.
data_ptr
(),
(
T
*
)
bias
.
data_ptr
(),
weight
.
size
(
1
),
bsz
,
Context
::
Instance
().
GetCurrentStream
());
}
template
<
typename
T
>
at
::
Tensor
ds_mlp_gemm
(
at
::
Tensor
&
input
,
at
::
Tensor
&
residual
,
at
::
Tensor
&
input_bias
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
bias
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
beta
,
const
float
epsilon
,
bool
preLayerNorm
,
bool
mlp_after_attn
)
{
auto
input_cont
=
input
.
contiguous
();
auto
options
=
at
::
TensorOptions
()
.
dtype
(
input_cont
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
output
=
at
::
empty
({
input_cont
.
size
(
0
),
input_cont
.
size
(
1
),
weight
.
size
(
1
)},
options
);
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
mlp_unfused_cublas
<
T
>
(
output
,
mlp_after_attn
?
input
:
residual
,
residual
,
input_bias
,
weight
,
bias
,
gamma
,
beta
,
epsilon
,
preLayerNorm
,
mlp_after_attn
);
return
output
;
}
template
<
typename
T
>
std
::
vector
<
at
::
Tensor
>
ds_mlp_gemm_int8
(
at
::
Tensor
&
input
,
at
::
Tensor
&
residual
,
at
::
Tensor
&
input_bias
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
bias
,
at
::
Tensor
&
gamma
,
at
::
Tensor
&
beta
,
const
float
epsilon
,
at
::
Tensor
&
q_scale
,
int
groups
,
bool
preLayerNorm
)
{
auto
input_cont
=
input
.
contiguous
();
auto
options
=
at
::
TensorOptions
()
.
dtype
(
input_cont
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
output
=
at
::
empty
({
input_cont
.
size
(
0
),
input_cont
.
size
(
1
),
weight
.
size
(
1
)},
options
);
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
auto
inp_norm
=
at
::
empty_like
(
input_cont
);
auto
residual_add
=
(
preLayerNorm
?
at
::
empty_like
(
input_cont
)
:
inp_norm
);
// computing the blocking across K dimension
// launch_residual_layer_norm((T*)inp_norm.data_ptr(),
// (T*)residual_add.data_ptr(),
// (T*)input_cont.data_ptr(),
// (T*)residual.data_ptr(),
// (T*)input_bias.data_ptr(),
// (T*)gamma.data_ptr(),
// (T*)beta.data_ptr(),
// epsilon,
// bsz,
// input_cont.size(2),
// preLayerNorm,
// Context::Instance().GetCurrentStream());
quantized_gemm
<
T
>
(
output
,
inp_norm
,
weight
,
q_scale
,
groups
,
0
);
launch_bias_gelu
((
T
*
)
output
.
data_ptr
(),
(
T
*
)
bias
.
data_ptr
(),
weight
.
size
(
1
),
bsz
,
Context
::
Instance
().
GetCurrentStream
());
return
{
output
,
residual_add
};
}
template
<
typename
T
>
at
::
Tensor
fused_gemm_gelu
(
at
::
Tensor
&
input
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
bias
,
at
::
Tensor
&
weight_out
,
const
float
epsilon
,
bool
preLayerNorm
,
bool
async_op
)
{
auto
input_cont
=
input
.
contiguous
();
auto
options
=
at
::
TensorOptions
()
.
dtype
(
input_cont
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
intermediate
=
at
::
empty
({
input_cont
.
size
(
0
),
input_cont
.
size
(
1
),
weight
.
size
(
1
)},
options
);
auto
output
=
at
::
empty
({
input_cont
.
size
(
0
),
input_cont
.
size
(
1
),
weight_out
.
size
(
1
)},
options
);
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
float
alpha
=
(
T
)
1.0
;
float
gemm_beta
=
(
T
)
0.0
;
cublasSetStream
(
Context
::
Instance
().
GetCublasHandle
(),
Context
::
Instance
().
GetCurrentStream
());
cublas_gemm_ex
(
Context
::
Instance
().
GetCublasHandle
(),
CUBLAS_OP_N
,
CUBLAS_OP_N
,
weight
.
size
(
1
),
bsz
,
input
.
size
(
2
),
&
alpha
,
&
gemm_beta
,
(
T
*
)
weight
.
data_ptr
(),
(
T
*
)
input_cont
.
data_ptr
(),
(
T
*
)
intermediate
.
data_ptr
(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
launch_bias_gelu
((
T
*
)
intermediate
.
data_ptr
(),
(
T
*
)
bias
.
data_ptr
(),
weight
.
size
(
1
),
bsz
,
Context
::
Instance
().
GetCurrentStream
());
cublas_gemm_ex
(
Context
::
Instance
().
GetCublasHandle
(),
CUBLAS_OP_N
,
CUBLAS_OP_N
,
weight_out
.
size
(
1
),
bsz
,
intermediate
.
size
(
2
),
&
alpha
,
&
gemm_beta
,
(
T
*
)
weight_out
.
data_ptr
(),
(
T
*
)
intermediate
.
data_ptr
(),
(
T
*
)
output
.
data_ptr
(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
// cudaEventRecord(Context::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true));
return
output
;
}
void
residual_add_bias
(
at
::
Tensor
&
output
,
at
::
Tensor
&
input
,
at
::
Tensor
&
attention_output
,
at
::
Tensor
&
output_b
,
at
::
Tensor
&
attention_b
,
int
mp_size
,
bool
mlp_after_attn
)
{
int
bsz
=
input
.
size
(
0
)
*
input
.
size
(
1
);
int
hidden_size
=
input
.
size
(
2
);
// cudaStreamWaitEvent(
// Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0);
if
(
input
.
scalar_type
()
==
at
::
kFloat
)
if
(
mlp_after_attn
)
launch_bias_residual
((
float
*
)
input
.
data_ptr
(),
(
float
*
)
output
.
data_ptr
(),
(
float
*
)
attention_output
.
data_ptr
(),
(
float
*
)
output_b
.
data_ptr
(),
(
float
*
)
attention_b
.
data_ptr
(),
bsz
,
hidden_size
,
mp_size
,
Context
::
Instance
().
GetCurrentStream
());
else
launch_gptj_residual_add
<
float
>
((
float
*
)
input
.
data_ptr
(),
(
float
*
)
output
.
data_ptr
(),
(
float
*
)
attention_output
.
data_ptr
(),
(
float
*
)
output_b
.
data_ptr
(),
(
float
*
)
attention_b
.
data_ptr
(),
hidden_size
,
bsz
,
mp_size
,
Context
::
Instance
().
GetCurrentStream
());
else
if
(
mlp_after_attn
)
launch_bias_residual
((
__half
*
)
input
.
data_ptr
(),
(
__half
*
)
output
.
data_ptr
(),
(
__half
*
)
attention_output
.
data_ptr
(),
(
__half
*
)
output_b
.
data_ptr
(),
(
__half
*
)
attention_b
.
data_ptr
(),
bsz
,
hidden_size
,
mp_size
,
Context
::
Instance
().
GetCurrentStream
());
else
launch_gptj_residual_add
<
__half
>
((
__half
*
)
input
.
data_ptr
(),
(
__half
*
)
output
.
data_ptr
(),
(
__half
*
)
attention_output
.
data_ptr
(),
(
__half
*
)
output_b
.
data_ptr
(),
(
__half
*
)
attention_b
.
data_ptr
(),
hidden_size
,
bsz
,
mp_size
,
Context
::
Instance
().
GetCurrentStream
());
}
std
::
vector
<
at
::
Tensor
>
apply_rotary_pos_emb
(
at
::
Tensor
&
mixed_query
,
at
::
Tensor
&
key_layer
,
unsigned
rotary_dim
,
unsigned
offset
,
unsigned
num_heads
,
bool
rotate_half
,
bool
rotate_every_two
)
{
auto
query_cont
=
mixed_query
.
contiguous
();
auto
key_cont
=
key_layer
.
contiguous
();
unsigned
bsz
=
mixed_query
.
size
(
0
);
unsigned
head_size
=
mixed_query
.
size
(
2
)
/
num_heads
;
unsigned
seq_len
=
mixed_query
.
size
(
1
);
if
(
mixed_query
.
scalar_type
()
==
at
::
kFloat
)
launch_apply_rotary_pos_emb
<
float
>
((
float
*
)
query_cont
.
data_ptr
(),
(
float
*
)
key_cont
.
data_ptr
(),
head_size
,
seq_len
,
rotary_dim
,
offset
,
num_heads
,
bsz
,
rotate_half
,
rotate_every_two
,
Context
::
Instance
().
GetCurrentStream
());
else
launch_apply_rotary_pos_emb
<
__half
>
((
__half
*
)
query_cont
.
data_ptr
(),
(
__half
*
)
key_cont
.
data_ptr
(),
head_size
,
seq_len
,
rotary_dim
,
offset
,
num_heads
,
bsz
,
rotate_half
,
rotate_every_two
,
Context
::
Instance
().
GetCurrentStream
());
return
{
query_cont
,
key_cont
};
}
template
<
typename
T
>
at
::
Tensor
fused_gemm_gelu_int8
(
at
::
Tensor
&
input
,
at
::
Tensor
&
weight
,
at
::
Tensor
&
bias
,
const
float
epsilon
,
at
::
Tensor
&
q_scale
,
int
groups
,
bool
preLayerNorm
)
{
auto
input_cont
=
input
.
contiguous
();
auto
options
=
at
::
TensorOptions
()
.
dtype
(
input_cont
.
options
().
dtype
())
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
auto
output
=
at
::
empty
({
input_cont
.
size
(
0
),
input_cont
.
size
(
1
),
weight
.
size
(
1
)},
options
);
int
bsz
=
input_cont
.
size
(
0
)
*
input_cont
.
size
(
1
);
quantized_gemm
<
T
>
(
output
,
input_cont
,
weight
,
q_scale
,
groups
,
0
);
launch_bias_gelu
((
T
*
)
output
.
data_ptr
(),
(
T
*
)
bias
.
data_ptr
(),
weight
.
size
(
1
),
bsz
,
Context
::
Instance
().
GetCurrentStream
());
return
output
;
}
at
::
Tensor
moe_res_matmul
(
at
::
Tensor
&
moe_res
,
at
::
Tensor
&
coef
,
at
::
Tensor
&
output
)
{
int
M
=
moe_res
.
size
(
0
)
*
moe_res
.
size
(
1
);
int
N
=
moe_res
.
size
(
2
);
Context
::
Instance
().
SynchComm
();
if
(
moe_res
.
scalar_type
()
==
at
::
kFloat
)
{
launch_moe_res_matmul
<
float
>
((
float
*
)
moe_res
.
data_ptr
(),
(
float
*
)
coef
.
data_ptr
(),
(
float
*
)
output
.
data_ptr
(),
M
,
N
,
at
::
cuda
::
getCurrentCUDAStream
());
}
else
{
launch_moe_res_matmul
<
__half
>
((
__half
*
)
moe_res
.
data_ptr
(),
(
__half
*
)
coef
.
data_ptr
(),
(
__half
*
)
output
.
data_ptr
(),
M
,
N
,
at
::
cuda
::
getCurrentCUDAStream
());
}
return
output
;
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"softmax_fp32"
,
&
ds_softmax
<
float
>
,
"DeepSpeed SoftMax with fp32 (CUDA)"
);
m
.
def
(
"softmax_fp16"
,
&
ds_softmax
<
__half
>
,
"DeepSpeed SoftMax with fp32 (CUDA)"
);
m
.
def
(
"softmax_context_fp32"
,
&
ds_softmax_context
<
float
>
,
"DeepSpeed attention with fp32 (CUDA)"
);
m
.
def
(
"softmax_context_fp16"
,
&
ds_softmax_context
<
__half
>
,
"DeepSpeed attention with fp32 (CUDA)"
);
m
.
def
(
"bias_gelu_fp32"
,
&
ds_bias_gelu
<
float
>
,
"DeepSpeed Gelu with fp32 (CUDA)"
);
m
.
def
(
"bias_gelu_fp16"
,
&
ds_bias_gelu
<
__half
>
,
"DeepSpeed Gelu with fp32 (CUDA)"
);
m
.
def
(
"bias_residual_fp32"
,
&
ds_bias_residual
<
float
>
,
"DeepSpeed residual-bias add with fp32 (CUDA)"
);
m
.
def
(
"bias_residual_fp16"
,
&
ds_bias_residual
<
__half
>
,
"DeepSpeed residual-bias add with fp32 (CUDA)"
);
m
.
def
(
"layer_norm_fp32"
,
&
ds_layernorm
<
float
>
,
"DeepSpeed layer-norm with fp32 (CUDA)"
);
m
.
def
(
"layer_norm_fp16"
,
&
ds_layernorm
<
__half
>
,
"DeepSpeed layer-norm with fp16 (CUDA)"
);
m
.
def
(
"qkv_gemm_fp32"
,
&
ds_qkv_gemm
<
float
>
,
"DeepSpeed qkv gemm with fp32 (CUDA)"
);
m
.
def
(
"qkv_gemm_fp16"
,
&
ds_qkv_gemm
<
__half
>
,
"DeepSpeed qkv gemm with fp16 (CUDA)"
);
m
.
def
(
"qkv_gemm_int8"
,
&
ds_qkv_gemm_int8
<
__half
>
,
"DeepSpeed qkv gemm with int8 (CUDA)"
);
m
.
def
(
"mlp_gemm_fp32"
,
&
ds_mlp_gemm
<
float
>
,
"DeepSpeed mlp with fp32 (CUDA)"
);
m
.
def
(
"mlp_gemm_fp16"
,
&
ds_mlp_gemm
<
__half
>
,
"DeepSpeed mlp with fp16 (CUDA)"
);
m
.
def
(
"mlp_gemm_int8"
,
&
ds_mlp_gemm_int8
<
__half
>
,
"DeepSpeed mlp with int8 (CUDA)"
);
m
.
def
(
"vector_matmul_fp32"
,
&
ds_vector_matmul
<
float
>
,
"DeepSpeed vector-MM with fp32 (CUDA)"
);
m
.
def
(
"vector_matmul_fp16"
,
&
ds_vector_matmul
<
__half
>
,
"DeepSpeed vector-MM with fp16 (CUDA)"
);
m
.
def
(
"vector_matmul_int8"
,
&
ds_vector_matmul_int8
<
__half
>
,
"DeepSpeed vector-MM with int8 (CUDA)"
);
m
.
def
(
"linear_layer_fp32"
,
&
ds_linear_layer
<
float
>
,
"DeepSpeed linear_layer with fp32 (CUDA)"
);
m
.
def
(
"linear_layer_fp16"
,
&
ds_linear_layer
<
__half
>
,
"DeepSpeed linear_layer with fp16 (CUDA)"
);
m
.
def
(
"linear_layer_int8"
,
&
ds_linear_layer_int8
<
__half
>
,
"DeepSpeed linear_layer with int8 (CUDA)"
);
m
.
def
(
"fused_gemm_gelu_fp32"
,
&
fused_gemm_gelu
<
float
>
,
"DeepSpeed mlp with fp32 (CUDA)"
);
m
.
def
(
"fused_gemm_gelu_fp16"
,
&
fused_gemm_gelu
<
__half
>
,
"DeepSpeed mlp with fp16 (CUDA)"
);
m
.
def
(
"residual_add"
,
&
residual_add_bias
,
"DeepSpeed mlp with fp16 (CUDA)"
);
m
.
def
(
"apply_rotary_pos_emb"
,
&
apply_rotary_pos_emb
,
"DeepSpeed mlp with fp16 (CUDA)"
);
m
.
def
(
"einsum_sec_sm_ecm_fp32"
,
&
einsum_sec_sm_ecm
<
float
>
,
"DeepSpeed vector-MM with fp32 (CUDA)"
);
m
.
def
(
"einsum_sec_sm_ecm_fp16"
,
&
einsum_sec_sm_ecm
<
__half
>
,
"DeepSpeed vector-MM with fp16 (CUDA)"
);
m
.
def
(
"moe_res_matmul"
,
&
moe_res_matmul
,
"DeepSpeed moe residual matmul (CUDA)"
);
}
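The GEMM helpers above all flatten a [batch, seq, hidden] activation to a 2-D [bsz, hidden] view (bsz = size(0) * size(1)) and pass the weight as the first operand, so that a column-major BLAS call produces the row-major result output = input x weight. The snippet below is a minimal CPU reference of that index math; it is an illustrative sketch and not part of this commit, and the helper name is hypothetical.

#include <cassert>
#include <vector>

// Row-major reference: out[bsz x n] = in[bsz x k] * w[k x n].
// Handing w to a column-major GEMM as the "A" operand (m = n, n = bsz, k = k) computes the
// same elements, because a row-major buffer reinterpreted as column-major is its transpose.
void rowmajor_gemm_ref(const std::vector<float>& in, const std::vector<float>& w,
                       std::vector<float>& out, int bsz, int k, int n)
{
    assert((int)in.size() == bsz * k && (int)w.size() == k * n && (int)out.size() == bsz * n);
    for (int r = 0; r < bsz; ++r)
        for (int c = 0; c < n; ++c) {
            float acc = 0.f;
            for (int i = 0; i < k; ++i) acc += in[r * k + i] * w[i * n + c];
            out[r * n + c] = acc;  // same element the extension writes at output[r, c]
        }
}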
csrc/transformer_bak/inference/csrc/pt_binding_hip.cpp
0 → 100644
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include <vector>
#include "context_hip.h"
#include "cublas_wrappers_hip.h"
#include "custom_hip_layers.h"
std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});

#define MAX_OUT_TOKES 10

template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores, at::Tensor& attn_mask, bool triangular,
                      bool recompute, bool local_attention, int window_size, bool async_op)
{
    auto attn_scores_c = attn_scores.contiguous();
    int bsz = attn_scores_c.size(0);
    int seq_len = attn_scores_c.size(1);
    int len = attn_scores_c.sizes().size();
    if (len > 3) seq_len = attn_scores_c.size(2);
    int soft_len = attn_scores_c.size(2);
    if (len > 3) soft_len = attn_scores_c.size(3);
    int heads = 1;
    if (len > 3) heads = attn_scores_c.size(1);

    launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(),
                           (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
                           triangular, recompute, local_attention, window_size, bsz, heads,
                           seq_len, soft_len, 1.0,
                           Context::Instance().GetCurrentStream(async_op));

    return attn_scores_c;
}

template <typename T>
void allocate_workspace(size_t hidden_dim, size_t max_seq_len, size_t batch_size,
                        size_t head_size = 128)
{
    size_t _workSpaceSize = (hidden_dim * batch_size * max_seq_len);
    Context::Instance().GenWorkSpace(_workSpaceSize * sizeof(T));
}

template <typename T>
at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
{
    auto options = at::TensorOptions()
                       .dtype(Q.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    T* workspace = (T*)Context::Instance().GetWorkSpace();
    float alpha = 1;
    float gemm_beta = 0.0;

    if (!workspace) {
        allocate_workspace<T>(W.size(1), MAX_OUT_TOKES, Q.size(0));
        workspace = (T*)Context::Instance().GetWorkSpace();
    }

    auto O = at::from_blob(workspace, {Q.size(1), Q.size(2), W.size(1)}, options);
    unsigned m = W.size(1);
    unsigned n = Q.size(1) * Q.size(2);
    unsigned k = Q.size(0);
    cublas_gemm_ex(Context::Instance().GetCublasHandle(), rocblas_operation_none,
                   rocblas_operation_transpose, m, n, k, &alpha, &gemm_beta,
                   (T*)W.data_ptr(), (T*)Q.data_ptr(), (T*)O.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    return O;
}

template <typename T>
void attention_unfused(at::Tensor& prev_key_cont, at::Tensor& query_cont, at::Tensor& attn_mask,
                       at::Tensor& prev_value_cont, at::Tensor& output, int& bsz, int& seq_len,
                       int& soft_len, int& heads, float& norm_factor, bool triangular,
                       bool recompute, bool local_attention, int window_size)
{
    auto options = at::TensorOptions()
                       .dtype(query_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    float alpha = norm_factor;
    float gemm_beta = 0.0;
    auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options);
    int k = prev_value_cont.size(2) / heads;
    rocblas_set_stream(Context::Instance().GetCublasHandle(),
                       Context::Instance().GetCurrentStream());
    cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), soft_len, seq_len, k,
                                &alpha, &gemm_beta, (T*)prev_key_cont.data_ptr(),
                                (T*)query_cont.data_ptr(), (T*)attn_score.data_ptr(),
                                rocblas_operation_none, rocblas_operation_none, soft_len * k,
                                seq_len * k, seq_len * soft_len, bsz * heads,
                                CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    attn_score = ds_softmax<T>(attn_score, attn_mask, triangular, recompute, local_attention,
                               window_size, false);
    alpha = 1.0;
    cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), k, seq_len, soft_len,
                                &alpha, &gemm_beta, (T*)prev_value_cont.data_ptr(),
                                (T*)attn_score.data_ptr(), (T*)output.data_ptr(),
                                rocblas_operation_none, rocblas_operation_none, soft_len * k,
                                seq_len * soft_len, seq_len * k, bsz * heads,
                                CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

template <typename T>
std::vector<at::Tensor> ds_softmax_context(at::Tensor& query, at::Tensor& prev_key,
                                           at::Tensor& new_key, at::Tensor& attn_mask,
                                           at::Tensor& prev_value, at::Tensor& new_value,
                                           int heads, float norm_factor, bool merging,
                                           bool triangular, bool local_attention,
                                           int window_size, bool no_masking)
{
    auto query_cont = query.contiguous();
    auto prev_key_cont = prev_key.contiguous();
    auto prev_value_cont = prev_value.contiguous();

    int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0);

    // Attn_Score [ batch Head Sequence-length Softmax-length]
    int bsz = query_cont.size(0);
    int seq_len = query_cont.size(1);
    int soft_len = prev_value.size(1);

    auto options = at::TensorOptions()
                       .dtype(query_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto output =
        at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options);
    attention_unfused<T>(prev_key_cont,
                         query_cont,
                         attn_mask,  //(no_masking ? nullptr : (T*)attn_mask.data_ptr()),
                         prev_value_cont, output, bsz, seq_len, soft_len, heads, norm_factor,
                         (triangular && (new_size == 0)), (new_size == 0), local_attention,
                         window_size);

    return {output, prev_key, prev_value};
}

template <typename T>
at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
{
    auto input_cont = input.contiguous();

    int bsz = input_cont.size(0) * input_cont.size(1);
    int intermediate_size = input_cont.size(2);

    launch_bias_gelu((T*)input_cont.data_ptr(), (T*)bias.data_ptr(), intermediate_size, bsz,
                     Context::Instance().GetCurrentStream());
    return input_cont;
}

template <typename T>
at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias)
{
    auto input_cont = input.contiguous();
    auto residual_cont = residual.contiguous();

    int bsz = input_cont.size(0) * input_cont.size(1);
    // launch_bias_residual((T*)input_cont.data_ptr(),
    //                      (T*)residual_cont.data_ptr(),
    //                      (T*)bias.data_ptr(),
    //                      bsz,
    //                      input_cont.size(2),
    //                      (bias.size(0) > 1),
    //                      Context::Instance().GetCurrentStream());
    return input_cont;
}

template <typename T>
at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& betta, float epsilon)
{
    int bsz = input_cont.size(0) * input_cont.size(1);
    auto inp_norm = at::empty_like(input_cont);
    launch_layer_norm((T*)inp_norm.data_ptr(), (T*)input_cont.data_ptr(), (T*)gamma.data_ptr(),
                      (T*)betta.data_ptr(), epsilon, bsz, input_cont.size(2),
                      Context::Instance().GetCurrentStream());
    return inp_norm;
}

template <typename T>
at::Tensor qkv_unfused_cublas(at::Tensor& output, at::Tensor& input, at::Tensor& weight,
                              at::Tensor& bias, at::Tensor& gamma, at::Tensor& beta,
                              const float epsilon, bool add_bias)
{
    auto inp_norm = ds_layernorm<T>(input, gamma, beta, epsilon);
    // hipEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());
    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    int bsz = input.size(0) * input.size(1);
    rocblas_set_stream(Context::Instance().GetCublasHandle(),
                       Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(), rocblas_operation_none,
                   rocblas_operation_none, weight.size(1), bsz, input.size(2), &alpha, &gemm_beta,
                   (T*)weight.data_ptr(), (T*)inp_norm.data_ptr(), (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    if (add_bias)
        launch_bias_add((T*)output.data_ptr(), (T*)bias.data_ptr(), weight.size(1), bsz,
                        Context::Instance().GetCurrentStream());
    return inp_norm;
}

template <typename T>
std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input, at::Tensor& weight, at::Tensor& bias,
                                    at::Tensor& gamma, at::Tensor& beta, const float epsilon,
                                    bool add_bias)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);
    auto inp_norm =
        qkv_unfused_cublas<T>(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias);

    return {output, inp_norm};
}

template <typename T>
void quantized_gemm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, at::Tensor& qscale,
                    int groups, int merge_count)
{
    int bsz = input.size(0) * input.size(1);
    auto options = at::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto weight16 = at::empty({weight.size(0), weight.size(1)}, options);

    launch_dequantize((T*)weight16.data_ptr(), (int8_t*)weight.data_ptr(),
                      (float*)qscale.data_ptr(), weight.size(1), weight.size(0), groups,
                      merge_count, Context::Instance().GetCurrentStream());

    rocblas_set_stream(Context::Instance().GetCublasHandle(),
                       Context::Instance().GetCurrentStream());

    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    cublas_gemm_ex(Context::Instance().GetCublasHandle(), rocblas_operation_none,
                   rocblas_operation_none, weight.size(1), bsz, input.size(2), &alpha, &gemm_beta,
                   (T*)weight16.data_ptr(), (T*)input.data_ptr(), (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

template <typename T>
at::Tensor ds_qkv_gemm_int8(at::Tensor& input, at::Tensor& weight, at::Tensor& bias,
                            at::Tensor& gamma, at::Tensor& beta, const float epsilon,
                            at::Tensor& q_scale, int groups, bool add_bias)
{
    int bsz = input.size(0) * input.size(1);
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);

    auto inp_norm = ds_layernorm<T>(input_cont, gamma, beta, epsilon);

    quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
    if (add_bias)
        launch_bias_add((T*)output.data_ptr(), (T*)bias.data_ptr(), weight.size(1), bsz,
                        Context::Instance().GetCurrentStream());

    return output;
}

template <typename T>
at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);

    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    rocblas_set_stream(Context::Instance().GetCublasHandle(),
                       Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(), rocblas_operation_none,
                   rocblas_operation_none, weight.size(1), bsz, input_cont.size(2), &alpha,
                   &gemm_beta, (T*)weight.data_ptr(), (T*)input_cont.data_ptr(),
                   (T*)output.data_ptr(), CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    launch_bias_add((T*)output.data_ptr(), (T*)bias.data_ptr(), weight.size(1), bsz,
                    Context::Instance().GetCurrentStream());

    return output;
}

template <typename T>
at::Tensor ds_linear_layer_int8(at::Tensor& input, at::Tensor& weight, at::Tensor& bias,
                                at::Tensor& q_scale, int groups)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    int bsz = input_cont.size(0) * input_cont.size(1);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);

    quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
    launch_bias_add((T*)output.data_ptr(), (T*)bias.data_ptr(), weight.size(1), bsz,
                    Context::Instance().GetCurrentStream());
    return output;
}

template <typename T>
at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);

    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    rocblas_set_stream(Context::Instance().GetCublasHandle(),
                       Context::Instance().GetCurrentStream(async_op));
    cublas_gemm_ex(Context::Instance().GetCublasHandle(), rocblas_operation_none,
                   rocblas_operation_none, weight.size(1), bsz, input_cont.size(2), &alpha,
                   &gemm_beta, (T*)weight.data_ptr(), (T*)input_cont.data_ptr(),
                   (T*)output.data_ptr(), CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    return output;
}

template <typename T>
at::Tensor ds_vector_matmul_int8(at::Tensor& input, at::Tensor& weight, at::Tensor& q_scale,
                                 int groups, int merge_count)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);

    quantized_gemm<T>(output, input_cont, weight, q_scale, groups, merge_count);
    return output;
}

template <typename T>
void mlp_unfused_cublas(at::Tensor& output, at::Tensor& input, at::Tensor& residual,
                        at::Tensor& input_bias, at::Tensor& weight, at::Tensor& bias,
                        at::Tensor& gamma, at::Tensor& beta, const float epsilon,
                        bool preLayerNorm, bool mlp_after_attn)
{
    int bsz = input.size(0) * input.size(1);
    auto inp_norm = at::empty_like(input);

    launch_residual_layer_norm((T*)inp_norm.data_ptr(), (T*)nullptr, (T*)input.data_ptr(),
                               (T*)residual.data_ptr(), (T*)input_bias.data_ptr(),
                               (T*)gamma.data_ptr(), (T*)beta.data_ptr(), epsilon, bsz,
                               input.size(2), preLayerNorm, mlp_after_attn,
                               Context::Instance().GetCurrentStream());

    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    rocblas_set_stream(Context::Instance().GetCublasHandle(),
                       Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(), rocblas_operation_none,
                   rocblas_operation_none, weight.size(1), bsz, input.size(2), &alpha, &gemm_beta,
                   (T*)weight.data_ptr(), (T*)inp_norm.data_ptr(), (T*)output.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    launch_bias_gelu((T*)output.data_ptr(), (T*)bias.data_ptr(), weight.size(1), bsz,
                     Context::Instance().GetCurrentStream());
}

template <typename T>
at::Tensor ds_mlp_gemm(at::Tensor& input, at::Tensor& residual, at::Tensor& input_bias,
                       at::Tensor& weight, at::Tensor& bias, at::Tensor& gamma, at::Tensor& beta,
                       const float epsilon, bool preLayerNorm, bool mlp_after_attn)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);

    mlp_unfused_cublas<T>(output, mlp_after_attn ? input : residual, residual, input_bias, weight,
                          bias, gamma, beta, epsilon, preLayerNorm, mlp_after_attn);

    return output;
}

template <typename T>
std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input, at::Tensor& residual,
                                         at::Tensor& input_bias, at::Tensor& weight,
                                         at::Tensor& bias, at::Tensor& gamma, at::Tensor& beta,
                                         const float epsilon, at::Tensor& q_scale, int groups,
                                         bool preLayerNorm)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);

    auto inp_norm = at::empty_like(input_cont);
    auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm);
    // computing the blocking across K dimension
    // launch_residual_layer_norm((T*)inp_norm.data_ptr(),
    //                            (T*)residual_add.data_ptr(),
    //                            (T*)input_cont.data_ptr(),
    //                            (T*)residual.data_ptr(),
    //                            (T*)input_bias.data_ptr(),
    //                            (T*)gamma.data_ptr(),
    //                            (T*)beta.data_ptr(),
    //                            epsilon,
    //                            bsz,
    //                            input_cont.size(2),
    //                            preLayerNorm,
    //                            Context::Instance().GetCurrentStream());

    quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
    launch_bias_gelu((T*)output.data_ptr(), (T*)bias.data_ptr(), weight.size(1), bsz,
                     Context::Instance().GetCurrentStream());

    return {output, residual_add};
}

template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input, at::Tensor& weight, at::Tensor& bias,
                           at::Tensor& weight_out, const float epsilon, bool preLayerNorm,
                           bool async_op)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto intermediate =
        at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);

    float alpha = (T)1.0;
    float gemm_beta = (T)0.0;
    rocblas_set_stream(Context::Instance().GetCublasHandle(),
                       Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(), rocblas_operation_none,
                   rocblas_operation_none, weight.size(1), bsz, input.size(2), &alpha, &gemm_beta,
                   (T*)weight.data_ptr(), (T*)input_cont.data_ptr(), (T*)intermediate.data_ptr(),
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    launch_bias_gelu((T*)intermediate.data_ptr(), (T*)bias.data_ptr(), weight.size(1), bsz,
                     Context::Instance().GetCurrentStream());
    cublas_gemm_ex(Context::Instance().GetCublasHandle(), rocblas_operation_none,
                   rocblas_operation_none, weight_out.size(1), bsz, intermediate.size(2), &alpha,
                   &gemm_beta, (T*)weight_out.data_ptr(), (T*)intermediate.data_ptr(),
                   (T*)output.data_ptr(), CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    // hipEventRecord(Context::Instance().GetCompEvent(2),
    //                Context::Instance().GetCurrentStream(true));
    return output;
}

void residual_add_bias(at::Tensor& output, at::Tensor& input, at::Tensor& attention_output,
                       at::Tensor& output_b, at::Tensor& attention_b, int mp_size,
                       bool mlp_after_attn)
{
    int bsz = input.size(0) * input.size(1);
    int hidden_size = input.size(2);
    // hipStreamWaitEvent(
    //     Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0);
    if (input.scalar_type() == at::kFloat)
        if (mlp_after_attn)
            launch_bias_residual((float*)input.data_ptr(), (float*)output.data_ptr(),
                                 (float*)attention_output.data_ptr(), (float*)output_b.data_ptr(),
                                 (float*)attention_b.data_ptr(), bsz, hidden_size, mp_size,
                                 Context::Instance().GetCurrentStream());
        else
            launch_gptj_residual_add<float>((float*)input.data_ptr(), (float*)output.data_ptr(),
                                            (float*)attention_output.data_ptr(),
                                            (float*)output_b.data_ptr(),
                                            (float*)attention_b.data_ptr(), hidden_size, bsz,
                                            mp_size, Context::Instance().GetCurrentStream());
    else if (mlp_after_attn)
        launch_bias_residual((__half*)input.data_ptr(), (__half*)output.data_ptr(),
                             (__half*)attention_output.data_ptr(), (__half*)output_b.data_ptr(),
                             (__half*)attention_b.data_ptr(), bsz, hidden_size, mp_size,
                             Context::Instance().GetCurrentStream());
    else
        launch_gptj_residual_add<__half>((__half*)input.data_ptr(), (__half*)output.data_ptr(),
                                         (__half*)attention_output.data_ptr(),
                                         (__half*)output_b.data_ptr(),
                                         (__half*)attention_b.data_ptr(), hidden_size, bsz,
                                         mp_size, Context::Instance().GetCurrentStream());
}

std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query, at::Tensor& key_layer,
                                             unsigned rotary_dim, unsigned offset,
                                             unsigned num_heads, bool rotate_half,
                                             bool rotate_every_two)
{
    auto query_cont = mixed_query.contiguous();
    auto key_cont = key_layer.contiguous();

    unsigned bsz = mixed_query.size(0);
    unsigned head_size = mixed_query.size(2) / num_heads;
    unsigned seq_len = mixed_query.size(1);

    if (mixed_query.scalar_type() == at::kFloat)
        launch_apply_rotary_pos_emb<float>((float*)query_cont.data_ptr(),
                                           (float*)key_cont.data_ptr(), head_size, seq_len,
                                           rotary_dim, offset, num_heads, bsz, rotate_half,
                                           rotate_every_two,
                                           Context::Instance().GetCurrentStream());
    else
        launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
                                            (__half*)key_cont.data_ptr(), head_size, seq_len,
                                            rotary_dim, offset, num_heads, bsz, rotate_half,
                                            rotate_every_two,
                                            Context::Instance().GetCurrentStream());
    return {query_cont, key_cont};
}

template <typename T>
at::Tensor fused_gemm_gelu_int8(at::Tensor& input, at::Tensor& weight, at::Tensor& bias,
                                const float epsilon, at::Tensor& q_scale, int groups,
                                bool preLayerNorm)
{
    auto input_cont = input.contiguous();
    auto options = at::TensorOptions()
                       .dtype(input_cont.options().dtype())
                       .layout(at::kStrided)
                       .device(at::kCUDA)
                       .requires_grad(false);

    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
    int bsz = input_cont.size(0) * input_cont.size(1);

    quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
    launch_bias_gelu((T*)output.data_ptr(), (T*)bias.data_ptr(), weight.size(1), bsz,
                     Context::Instance().GetCurrentStream());

    return output;
}

at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& output)
{
    int M = moe_res.size(0) * moe_res.size(1);
    int N = moe_res.size(2);
    Context::Instance().SynchComm();
    if (moe_res.scalar_type() == at::kFloat) {
        launch_moe_res_matmul<float>((float*)moe_res.data_ptr(), (float*)coef.data_ptr(),
                                     (float*)output.data_ptr(), M, N,
                                     at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    } else {
        launch_moe_res_matmul<__half>((__half*)moe_res.data_ptr(), (__half*)coef.data_ptr(),
                                      (__half*)output.data_ptr(), M, N,
                                      at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
    }
    return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
    m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp32 (CUDA)");
    m.def("softmax_context_fp32", &ds_softmax_context<float>, "DeepSpeed attention with fp32 (CUDA)");
    m.def("softmax_context_fp16", &ds_softmax_context<__half>, "DeepSpeed attention with fp32 (CUDA)");
    m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
    m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp32 (CUDA)");
    m.def("bias_residual_fp32", &ds_bias_residual<float>, "DeepSpeed residual-bias add with fp32 (CUDA)");
    m.def("bias_residual_fp16", &ds_bias_residual<__half>, "DeepSpeed residual-bias add with fp32 (CUDA)");
    m.def("layer_norm_fp32", &ds_layernorm<float>, "DeepSpeed layer-norm with fp32 (CUDA)");
    m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)");
    m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
    m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)");
    m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)");
    m.def("mlp_gemm_fp32", &ds_mlp_gemm<float>, "DeepSpeed mlp with fp32 (CUDA)");
    m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)");
    m.def("vector_matmul_fp32", &ds_vector_matmul<float>, "DeepSpeed vector-MM with fp32 (CUDA)");
    m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)");
    m.def("vector_matmul_int8", &ds_vector_matmul_int8<__half>, "DeepSpeed vector-MM with int8 (CUDA)");
    m.def("linear_layer_fp32", &ds_linear_layer<float>, "DeepSpeed linear_layer with fp32 (CUDA)");
    m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)");
    m.def("linear_layer_int8", &ds_linear_layer_int8<__half>, "DeepSpeed linear_layer with int8 (CUDA)");
    m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
    m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("residual_add", &residual_add_bias, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
    m.def("einsum_sec_sm_ecm_fp32", &einsum_sec_sm_ecm<float>, "DeepSpeed vector-MM with fp32 (CUDA)");
    m.def("einsum_sec_sm_ecm_fp16", &einsum_sec_sm_ecm<__half>, "DeepSpeed vector-MM with fp16 (CUDA)");
    m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
}
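attention_unfused above runs one GEMM problem per (batch, head) pair, and the strides it hands to cublas_strided_batched_gemm (soft_len * k, seq_len * k, seq_len * soft_len) are simply the per-head element counts of the key, query, and score buffers. The stand-alone sketch below replays that stride arithmetic for hypothetical example sizes; it is illustrative only and not part of this commit.

#include <cassert>

int main()
{
    // Hypothetical shapes: batch = 2, heads = 16, head size k = 64,
    // query length seq_len = 1, key/value length soft_len = 128.
    const int bsz = 2, heads = 16, k = 64, seq_len = 1, soft_len = 128;

    const int stride_key = soft_len * k;          // elements per head in prev_key
    const int stride_query = seq_len * k;         // elements per head in query
    const int stride_score = seq_len * soft_len;  // elements per head in attn_score

    // The batched GEMM solves bsz * heads independent problems, one per head slice.
    const int batch_count = bsz * heads;
    assert(batch_count * stride_key == bsz * heads * soft_len * k);
    assert(batch_count * stride_query == bsz * heads * seq_len * k);
    assert(batch_count * stride_score == bsz * heads * seq_len * soft_len);
    return 0;
}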
csrc/transformer_bak/inference/csrc/softmax.cu
0 → 100644
#include <limits>
#include "custom_cuda_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define ATTN_THREADS 1024
#define MAX_REG_SIZE 8
#define minus_infinity -10000.0
void
CheckCudaErrorAux
(
const
char
*
file
,
unsigned
line
)
{
cudaError_t
err
=
cudaGetLastError
();
if
(
err
==
cudaSuccess
)
return
;
std
::
cerr
<<
cudaGetErrorString
(
err
)
<<
"("
<<
err
<<
") at "
<<
file
<<
":"
<<
line
<<
std
::
endl
;
throw
std
::
runtime_error
(
"CUDA ERROR!!!
\n
"
);
}
#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__)
namespace
cg
=
cooperative_groups
;
__global__
void
attn_softmax_v2
(
__half
*
vals
,
__half
*
mask
,
bool
triangular
,
bool
recompute
,
bool
local_attention
,
int
window_size
,
int
total_count
,
int
heads
,
int
sequence_length
,
int
num_seq
,
float
scale
,
int
iterations
,
int
reduceWidth
)
{
#ifdef HALF_PRECISION_AVAILABLE
cg
::
thread_block
b
=
cg
::
this_thread_block
();
cg
::
thread_block_tile
<
WARP_SIZE
>
g
=
cg
::
tiled_partition
<
WARP_SIZE
>
(
b
);
float2
low_data
[
MAX_REG_SIZE
];
float2
high_data
[
MAX_REG_SIZE
];
__half2
h_scale
=
__float2half2_rn
(
scale
);
int
wid
=
threadIdx
.
x
>>
5
;
int
lane
=
threadIdx
.
x
&
0x1f
;
int
warp_num
=
blockDim
.
x
>>
5
;
int
reduce_blocks
=
reduceWidth
>>
5
;
int
seq_lane
=
threadIdx
.
x
%
reduceWidth
;
__shared__
float
partialSum
[
MAX_WARP_NUM
];
int
iter_offset
=
blockIdx
.
x
*
(
warp_num
/
reduce_blocks
)
+
(
wid
/
reduce_blocks
);
if
(
iter_offset
<
total_count
)
{
vals
+=
(
iter_offset
*
sequence_length
);
int
mask_offset
=
(
iter_offset
/
(
heads
*
num_seq
))
*
(
sequence_length
);
int
seq_id
=
iter_offset
%
num_seq
;
int
seq_id4
=
seq_id
>>
2
;
int
real_seq_id
=
seq_id
+
(
num_seq
==
sequence_length
?
0
:
sequence_length
);
int
window_stride4
=
(
local_attention
&&
(
real_seq_id
>>
2
)
>
(
window_size
>>
2
))
?
(
real_seq_id
>>
2
)
-
(
window_size
>>
2
)
:
0
;
int
window_stride
=
(
local_attention
&&
real_seq_id
>=
window_size
)
?
real_seq_id
-
window_size
:
-
1
;
float
max_val
=
minus_infinity
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
int
data_id
=
i
*
(
reduceWidth
<<
2
)
+
(
seq_lane
<<
2
);
if
((
!
triangular
||
((
data_id
>>
2
)
<=
seq_id4
))
&&
(
data_id
>>
2
)
>=
window_stride4
&&
data_id
<
sequence_length
)
{
if
((
sequence_length
-
data_id
)
>=
4
)
{
low_data
[
i
].
x
=
data_id
>
window_stride
?
__half2float
(
vals
[
data_id
])
:
minus_infinity
;
low_data
[
i
].
y
=
((
!
triangular
||
((
data_id
+
1
)
<=
seq_id
))
&&
(
data_id
+
1
)
>
window_stride
)
?
__half2float
(
vals
[
data_id
+
1
])
:
minus_infinity
;
high_data
[
i
].
x
=
((
!
triangular
||
((
data_id
+
2
)
<=
seq_id
))
&&
(
data_id
+
2
)
>
window_stride
)
?
__half2float
(
vals
[
data_id
+
2
])
:
minus_infinity
;
high_data
[
i
].
y
=
((
!
triangular
||
((
data_id
+
3
)
<=
seq_id
))
&&
(
data_id
+
3
)
>
window_stride
)
?
__half2float
(
vals
[
data_id
+
3
])
:
minus_infinity
;
if
(
mask
&&
recompute
)
{
low_data
[
i
].
x
+=
__half2float
(
mask
[
data_id
+
mask_offset
]);
low_data
[
i
].
y
+=
__half2float
(
mask
[
data_id
+
mask_offset
+
1
]);
high_data
[
i
].
x
+=
__half2float
(
mask
[
data_id
+
mask_offset
+
2
]);
high_data
[
i
].
y
+=
__half2float
(
mask
[
data_id
+
mask_offset
+
3
]);
}
}
else
{
low_data
[
i
].
x
=
data_id
>
window_stride
?
__half2float
(
vals
[
data_id
])
:
minus_infinity
;
low_data
[
i
].
y
=
(((
!
triangular
||
(
data_id
+
1
)
<=
seq_id
)
&&
(
data_id
+
1
)
>
window_stride
)
&&
(
data_id
+
1
)
<
sequence_length
)
?
__half2float
(
vals
[
data_id
+
1
])
:
minus_infinity
;
high_data
[
i
].
x
=
(((
!
triangular
||
(
data_id
+
2
)
<=
seq_id
)
&&
(
data_id
+
2
)
>
window_stride
)
&&
(
data_id
+
2
)
<
sequence_length
)
?
__half2float
(
vals
[
data_id
+
2
])
:
minus_infinity
;
high_data
[
i
].
y
=
minus_infinity
;
if
(
mask
&&
recompute
)
{
low_data
[
i
].
x
+=
__half2float
(
mask
[
data_id
+
mask_offset
]);
if
((
data_id
+
1
)
<
sequence_length
)
low_data
[
i
].
y
+=
__half2float
(
mask
[
data_id
+
mask_offset
+
1
]);
if
((
data_id
+
2
)
<
sequence_length
)
high_data
[
i
].
x
+=
__half2float
(
mask
[
data_id
+
mask_offset
+
2
]);
}
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val
=
(
low_data
[
i
].
x
>
max_val
?
low_data
[
i
].
x
:
max_val
);
max_val
=
(
low_data
[
i
].
y
>
max_val
?
low_data
[
i
].
y
:
max_val
);
max_val
=
(
high_data
[
i
].
x
>
max_val
?
high_data
[
i
].
x
:
max_val
);
max_val
=
(
high_data
[
i
].
y
>
max_val
?
high_data
[
i
].
y
:
max_val
);
}
else
{
low_data
[
i
].
x
=
minus_infinity
;
low_data
[
i
].
y
=
minus_infinity
;
high_data
[
i
].
x
=
minus_infinity
;
high_data
[
i
].
y
=
minus_infinity
;
}
}
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
*=
2
)
{
auto
temp
=
g
.
shfl_xor
(
max_val
,
i
);
max_val
=
(
temp
>
max_val
?
temp
:
max_val
);
}
if
(
reduceWidth
>
WARP_SIZE
)
{
if
(
lane
==
0
)
partialSum
[
wid
]
=
max_val
;
b
.
sync
();
if
(
lane
<
warp_num
)
max_val
=
partialSum
[
lane
];
b
.
sync
();
for
(
int
i
=
1
;
i
<
reduce_blocks
;
i
*=
2
)
{
auto
temp
=
g
.
shfl_xor
(
max_val
,
i
);
max_val
=
(
temp
>
max_val
?
temp
:
max_val
);
}
max_val
=
g
.
shfl
(
max_val
,
threadIdx
.
x
/
WARP_SIZE
);
}
float
sum
=
0
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
low_data
[
i
].
x
=
__expf
(
low_data
[
i
].
x
-
max_val
);
low_data
[
i
].
y
=
__expf
(
low_data
[
i
].
y
-
max_val
);
high_data
[
i
].
x
=
__expf
(
high_data
[
i
].
x
-
max_val
);
high_data
[
i
].
y
=
__expf
(
high_data
[
i
].
y
-
max_val
);
sum
+=
(
low_data
[
i
].
x
+
low_data
[
i
].
y
+
high_data
[
i
].
x
+
high_data
[
i
].
y
);
}
for
(
int
i
=
1
;
i
<
WARP_SIZE
;
i
*=
2
)
sum
+=
g
.
shfl_xor
(
sum
,
i
);
if
(
reduceWidth
>
WARP_SIZE
)
{
if
(
lane
==
0
)
partialSum
[
wid
]
=
sum
;
b
.
sync
();
if
(
lane
<
warp_num
)
sum
=
partialSum
[
lane
];
b
.
sync
();
for
(
int
i
=
1
;
i
<
reduce_blocks
;
i
*=
2
)
{
sum
+=
g
.
shfl_xor
(
sum
,
i
);
}
sum
=
g
.
shfl
(
sum
,
threadIdx
.
x
/
WARP_SIZE
);
}
sum
+=
1e-6
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
int
data_id
=
i
*
(
reduceWidth
<<
2
)
+
(
seq_lane
<<
2
);
if
(
data_id
<
sequence_length
)
{
if
((
sequence_length
-
data_id
)
>=
4
)
{
vals
[
data_id
]
=
low_data
[
i
].
x
/
sum
;
vals
[
data_id
+
1
]
=
low_data
[
i
].
y
/
sum
;
vals
[
data_id
+
2
]
=
high_data
[
i
].
x
/
sum
;
vals
[
data_id
+
3
]
=
high_data
[
i
].
y
/
sum
;
}
else
{
vals
[
data_id
]
=
low_data
[
i
].
x
/
sum
;
if
((
data_id
+
1
)
<
sequence_length
)
vals
[
data_id
+
1
]
=
low_data
[
i
].
y
/
sum
;
if
((
data_id
+
2
)
<
sequence_length
)
vals
[
data_id
+
2
]
=
high_data
[
i
].
x
/
sum
;
}
}
}
}
#endif
}
__global__
void
attn_softmax_v2
(
float
*
vals
,
float
*
attn_mask
,
bool
triangular
,
bool
recompute
,
bool
local_attention
,
int
window_size
,
int
total_count
,
int
heads
,
int
sequence_length
,
int
num_seq
,
float
scale
,
int
iterations
,
int
reduceWidth
)
{
cg
::
thread_block
b
=
cg
::
this_thread_block
();
cg
::
thread_block_tile
<
WARP_SIZE
>
g
=
cg
::
tiled_partition
<
WARP_SIZE
>
(
b
);
float4
data
[
MAX_REG_SIZE
];
int
wid
=
threadIdx
.
x
>>
5
;
int
lane
=
threadIdx
.
x
&
0x1f
;
int
warp_num
=
blockDim
.
x
>>
5
;
int
reduce_blocks
=
reduceWidth
>>
5
;
int
seq_lane
=
threadIdx
.
x
%
reduceWidth
;
__shared__
float
partialSum
[
MAX_WARP_NUM
];
int
iter_offset
=
blockIdx
.
x
*
(
warp_num
/
reduce_blocks
)
+
(
wid
/
reduce_blocks
);
if
(
iter_offset
<
total_count
)
{
vals
+=
(
iter_offset
*
sequence_length
);
int
mask_offset
=
(
iter_offset
/
(
heads
*
num_seq
))
*
(
sequence_length
);
int
seq_id
=
iter_offset
%
num_seq
;
int
seq_id4
=
seq_id
>>
2
;
int
real_seq_id
=
seq_id
+
(
num_seq
==
sequence_length
?
0
:
sequence_length
);
int
window_stride4
=
(
local_attention
&&
(
real_seq_id
>>
2
)
>
(
window_size
>>
2
))
?
(
real_seq_id
>>
2
)
-
(
window_size
>>
2
)
:
0
;
int
window_stride
=
(
local_attention
&&
real_seq_id
>=
window_size
)
?
real_seq_id
-
window_size
:
-
1
;
float
max_val
=
minus_infinity
;
for
(
int
i
=
0
;
i
<
iterations
;
i
++
)
{
int
data_id
=
i
*
(
reduceWidth
<<
2
)
+
(
seq_lane
<<
2
);
if
((
!
triangular
||
((
data_id
>>
2
)
<=
seq_id4
))
&&
(
data_id
>>
2
)
>=
window_stride4
&&
data_id
<
sequence_length
)
{
if
((
sequence_length
-
data_id
)
>=
4
)
{
data
[
i
].
x
=
(
data_id
>
window_stride
?
vals
[
data_id
]
:
minus_infinity
);
data
[
i
].
y
=
((
!
triangular
||
((
data_id
+
1
)
<=
seq_id
))
&&
(
data_id
+
1
)
>
window_stride
)
?
vals
[
data_id
+
1
]
:
minus_infinity
;
data
[
i
].
z
=
((
!
triangular
||
((
data_id
+
2
)
<=
seq_id)) &&
                             (data_id + 2) > window_stride)
                                ? vals[data_id + 2]
                                : minus_infinity;
                data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
                             (data_id + 3) > window_stride)
                                ? vals[data_id + 3]
                                : minus_infinity;
                if (attn_mask && recompute) {
                    data[i].x += attn_mask[data_id + mask_offset];
                    data[i].y += attn_mask[data_id + mask_offset + 1];
                    data[i].z += attn_mask[data_id + mask_offset + 2];
                    data[i].w += attn_mask[data_id + mask_offset + 3];
                }
            } else {
                data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
                data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
                             (data_id + 1) > window_stride && (data_id + 1) < sequence_length)
                                ? (vals[data_id + 1])
                                : minus_infinity;
                data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
                             (data_id + 2) > window_stride && (data_id + 2) < sequence_length)
                                ? (vals[data_id + 2])
                                : minus_infinity;
                data[i].w = minus_infinity;
                if (attn_mask && recompute) {
                    data[i].x += attn_mask[data_id + mask_offset];
                    if ((data_id + 1) < sequence_length)
                        data[i].y += attn_mask[data_id + mask_offset + 1];
                    if ((data_id + 2) < sequence_length)
                        data[i].z += attn_mask[data_id + mask_offset + 2];
                }
            }
            max_val = (data[i].x > max_val ? data[i].x : max_val);
            max_val = (data[i].y > max_val ? data[i].y : max_val);
            max_val = (data[i].z > max_val ? data[i].z : max_val);
            max_val = (data[i].w > max_val ? data[i].w : max_val);
        } else {
            data[i].x = minus_infinity;
            data[i].y = minus_infinity;
            data[i].z = minus_infinity;
            data[i].w = minus_infinity;
        }
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) {
        auto temp = g.shfl_xor(max_val, i);
        max_val = (temp > max_val ? temp : max_val);
    }

    if (reduceWidth > WARP_SIZE) {
        if (lane == 0) partialSum[wid] = max_val;
        b.sync();

        if (lane < warp_num) max_val = partialSum[lane];
        b.sync();

        for (int i = 1; i < reduce_blocks; i *= 2) {
            auto temp = g.shfl_xor(max_val, i);
            max_val = (temp > max_val ? temp : max_val);
        }

        max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
    }

    float sum = 0;
    for (int i = 0; i < iterations; i++) {
        data[i].x = __expf(data[i].x - max_val);
        data[i].y = __expf(data[i].y - max_val);
        data[i].z = __expf(data[i].z - max_val);
        data[i].w = __expf(data[i].w - max_val);

        sum += (data[i].x + data[i].y + data[i].z + data[i].w);
    }

    for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);

    if (reduceWidth > WARP_SIZE) {
        if (lane == 0) partialSum[wid] = sum;
        b.sync();

        if (lane < warp_num) sum = partialSum[lane];
        b.sync();

        for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }

        sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
    }

    sum += 1e-6;

    for (int i = 0; i < iterations; i++) {
        int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
        if (data_id < sequence_length) {
            if ((sequence_length - data_id) >= 4) {
                vals[data_id] = data[i].x / sum;
                vals[data_id + 1] = data[i].y / sum;
                vals[data_id + 2] = data[i].z / sum;
                vals[data_id + 3] = data[i].w / sum;
            } else {
                vals[data_id] = data[i].x / sum;
                if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum;
                if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum;
            }
        }
    }
    }
}

template <typename T>
void launch_attn_softmax_v2(T* vals,
                            T* mask,
                            bool triangular,
                            bool recompute,
                            bool local_attention,
                            int window_size,
                            int batch_size,
                            int heads,
                            int num_seq,
                            int sequence_length,
                            float scale,
                            cudaStream_t stream)
{
    int total_count = batch_size * heads * num_seq;
    dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1);
    dim3 block_dim(ATTN_THREADS);
    const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE;
    const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;

    if (sequence_length <= 32768)
        attn_softmax_v2<<<grid_dim, block_dim, 0, stream>>>(vals,
                                                            mask,
                                                            triangular,
                                                            recompute,
                                                            local_attention,
                                                            window_size,
                                                            total_count,
                                                            (triangular ? (heads * batch_size) : heads),
                                                            sequence_length,
                                                            num_seq,
                                                            scale,
                                                            iterations,
                                                            reduce_width);
    else
        throw std::runtime_error("Unsupport Seq_Length!");
}

template void launch_attn_softmax_v2(float* vals,
                                     float* mask,
                                     bool triangular,
                                     bool recompute,
                                     bool local_attention,
                                     int window_size,
                                     int batch_size,
                                     int heads,
                                     int num_seq,
                                     int sequence_length,
                                     float scale,
                                     cudaStream_t stream);
template void launch_attn_softmax_v2(__half* vals,
                                     __half* mask,
                                     bool triangular,
                                     bool recompute,
                                     bool local_attention,
                                     int window_size,
                                     int batch_size,
                                     int heads,
                                     int num_seq,
                                     int sequence_length,
                                     float scale,
                                     cudaStream_t stream);
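Illustrative usage sketch (not part of the committed diff): the host-side launcher above is normally driven from a PyTorch binding with an attention-score buffer laid out as [batch, heads, num_seq, sequence_length]; the function name, shapes, and stream below are example assumptions, not code from this commit.

// Hypothetical caller of launch_attn_softmax_v2; shapes, window size, and scale are example values.
void softmax_usage_example(float* scores, float* attn_mask, cudaStream_t stream)
{
    const int batch_size = 8, heads = 16, num_seq = 128, sequence_length = 128;
    launch_attn_softmax_v2(scores,
                           attn_mask,
                           /*triangular=*/true,
                           /*recompute=*/true,
                           /*local_attention=*/false,
                           /*window_size=*/256,
                           batch_size,
                           heads,
                           num_seq,
                           sequence_length,
                           /*scale=*/1.0f,
                           stream);  // softmax is applied in place on `scores`
}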
csrc/transformer_bak/inference/csrc/softmax.hip
0 → 100644
View file @
7d1a83a9
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <limits>
#include "custom_hip_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define ATTN_THREADS 1024
#define MAX_REG_SIZE 8
#define minus_infinity -10000.0
void CheckCudaErrorAux(const char* file, unsigned line)
{
hipError_t err = hipGetLastError();
if (err == hipSuccess) return;
std::cerr << hipGetErrorString(err) << "(" << err << ") at " << file << ":" << line
<< std::endl;
throw std::runtime_error("CUDA ERROR!!!\n");
}
#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__)
namespace cg = cooperative_groups;
__global__ void attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale,
int iterations,
int reduceWidth)
{
#ifdef HALF_PRECISION_AVAILABLE
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float2 low_data[MAX_REG_SIZE];
float2 high_data[MAX_REG_SIZE];
__half2 h_scale = __float2half2_rn(scale);
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? __half2float(vals[data_id + 3])
: minus_infinity;
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
}
} else {
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) &&
(data_id + 1) > window_stride) &&
(data_id + 1) < sequence_length)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) &&
(data_id + 2) > window_stride) &&
(data_id + 2) < sequence_length)
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = minus_infinity;
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
if ((data_id + 1) < sequence_length)
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
if ((data_id + 2) < sequence_length)
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
}
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
} else {
low_data[i].x = minus_infinity;
low_data[i].y = minus_infinity;
high_data[i].x = minus_infinity;
high_data[i].y = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = low_data[i].x / sum;
vals[data_id + 1] = low_data[i].y / sum;
vals[data_id + 2] = high_data[i].x / sum;
vals[data_id + 3] = high_data[i].y / sum;
} else {
vals[data_id] = low_data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = low_data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = high_data[i].x / sum;
}
}
}
}
#endif
}
__global__ void attn_softmax_v2(float* vals,
float* attn_mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale,
int iterations,
int reduceWidth)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float4 data[MAX_REG_SIZE];
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity);
data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? vals[data_id + 1]
: minus_infinity;
data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? vals[data_id + 2]
: minus_infinity;
data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? vals[data_id + 3]
: minus_infinity;
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
data[i].y += attn_mask[data_id + mask_offset + 1];
data[i].z += attn_mask[data_id + mask_offset + 2];
data[i].w += attn_mask[data_id + mask_offset + 3];
}
} else {
data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride && (data_id + 1) < sequence_length)
? (vals[data_id + 1])
: minus_infinity;
data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride && (data_id + 2) < sequence_length)
? (vals[data_id + 2])
: minus_infinity;
data[i].w = minus_infinity;
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
if ((data_id + 1) < sequence_length)
data[i].y += attn_mask[data_id + mask_offset + 1];
if ((data_id + 2) < sequence_length)
data[i].z += attn_mask[data_id + mask_offset + 2];
}
}
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = data[i].x / sum;
vals[data_id + 1] = data[i].y / sum;
vals[data_id + 2] = data[i].z / sum;
vals[data_id + 3] = data[i].w / sum;
} else {
vals[data_id] = data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum;
}
}
}
}
}
template <typename T>
void launch_attn_softmax_v2(T* vals,
T* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
hipStream_t stream)
{
int total_count = batch_size * heads * num_seq;
dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1);
dim3 block_dim(ATTN_THREADS);
const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE;
const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;
if (sequence_length <= 32768)
hipLaunchKernelGGL(( attn_softmax_v2), dim3(grid_dim), dim3(block_dim), 0, stream,
vals,
mask,
triangular,
recompute,
local_attention,
window_size,
total_count,
(triangular ? (heads * batch_size) : heads),
sequence_length,
num_seq,
scale,
iterations,
reduce_width);
else
throw std::runtime_error("Unsupport Seq_Length!");
}
template void launch_attn_softmax_v2(float* vals,
float* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
hipStream_t stream);
template void launch_attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
hipStream_t stream);
csrc/transformer_bak/inference/includes/context.h
0 → 100644
View file @
7d1a83a9
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#define WARP_SIZE 32
#define CUDA_CHECK(callstr) \
{ \
cudaError_t error_code = callstr; \
if (error_code != cudaSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
    return std::max(
        std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
        // Use at least 1 block, since CUDA does not allow empty block
        1);
}

class Context {
public:
    Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0)
    {
        curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
        curandSetPseudoRandomGeneratorSeed(_gen, 123);
        if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
            auto message = std::string("Fail to create cublas handle.");
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
        cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
        cudaEventCreate(&_comp1_event, (cudaEventDisableTiming | cudaEventBlockingSync));
        cudaEventCreate(&_comp2_event, (cudaEventDisableTiming | cudaEventBlockingSync));
        cudaEventCreate(&_comp_event, (cudaEventDisableTiming | cudaEventBlockingSync));
        cudaEventCreate(&_comm_event, (cudaEventDisableTiming | cudaEventBlockingSync));
    }

    virtual ~Context()
    {
        cublasDestroy(_cublasHandle);
        cudaFree(_workspace);
        cudaEventDestroy(_comp1_event);
        cudaEventDestroy(_comp2_event);
        cudaEventDestroy(_comp_event);
        cudaEventDestroy(_comm_event);
    }

    static Context& Instance()
    {
        static Context _ctx;
        return _ctx;
    }

    void GenWorkSpace(size_t size)
    {
        if (!_workspace) {
            assert(_workspace == nullptr);
            cudaMalloc(&_workspace, size);
        } else if (_workSpaceSize < size) {
            cudaFree(_workspace);
            cudaMalloc(&_workspace, size);
        }

        _workSpaceSize = size;
    }

    cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }

    size_t get_workspace_size() const { return _workSpaceSize; }
    void* GetWorkSpace() { return _workspace; }

    inline unsigned new_token(unsigned layer_id)
    {
        if (layer_id == 0) _token_length++;
        return _token_length;
    }

    inline void reset_tokens(unsigned initial_tokens = 0)
    {
        _num_tokens = initial_tokens;
    }  //_token_length = 0; }

    inline unsigned current_tokens() const { return _num_tokens; }

    inline void advance_tokens() { _num_tokens++; }

    curandGenerator_t& GetRandGenerator() { return _gen; }

    cudaStream_t GetCommStream(bool async_op = false)
    {
        if (!_comm_stream)
            _comm_stream =
                async_op ? at::cuda::getStreamFromPool(true) : at::cuda::getCurrentCUDAStream();
        return _comm_stream;
    }

    cudaStream_t GetCurrentStream(bool other_stream = false)
    {
        // get current pytorch stream.
        if (other_stream) {
            if (!_stream) _stream = at::cuda::getStreamFromPool(true);
            return _stream;
        }
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        return stream;
    }

    cublasHandle_t GetCublasHandle() { return _cublasHandle; }

    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }

    inline void SynchComp()
    {
        cudaEventRecord(_comp_event, _comp_stream);
        cudaStreamWaitEvent(_comm_stream, _comp_event, 0);
    }
    inline void SynchComm()
    {
        cudaEventRecord(_comm_event, _comm_stream);
        cudaStreamWaitEvent(_comp_stream, _comm_event, 0);
    }

private:
    curandGenerator_t _gen;
    cublasHandle_t _cublasHandle;
    cudaEvent_t _comp_event;
    cudaEvent_t _comm_event;

    void* _workspace;
    uint64_t _seed;
    uint64_t _curr_offset;
    size_t _workSpaceSize;

    cudaEvent_t _comp1_event;
    cudaEvent_t _comp2_event;

    cudaStream_t _stream;

    unsigned _token_length;
    unsigned _num_tokens;

    std::vector<std::array<int, 3>> _gemm_algos;

    cudaStream_t _comp_stream;
    cudaStream_t _comm_stream;

    std::unordered_map<int, int> _world_sizes;
};
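Illustrative usage sketch (not part of the committed diff): the inference ops share scratch memory and handles through this singleton, so a hypothetical host-side op would size the workspace once and then launch on the current PyTorch stream; the function name and the 16 MiB figure below are arbitrary example assumptions.

// Hypothetical caller of the Context singleton; the workspace size is an example value.
void context_usage_example()
{
    Context::Instance().GenWorkSpace(16 * 1024 * 1024);               // grow the shared scratch buffer if needed
    void* scratch = Context::Instance().GetWorkSpace();               // reused across layers
    cudaStream_t stream = Context::Instance().GetCurrentStream();     // current PyTorch stream
    cublasHandle_t handle = Context::Instance().GetCublasHandle();    // shared cuBLAS handle
    (void)scratch;
    (void)stream;
    (void)handle;
}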
csrc/transformer_bak/inference/includes/context_hip.h
0 → 100644
View file @
7d1a83a9
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "rocblas.h"
#include "hip/hip_runtime.h"
#include "hiprand/hiprand.h"
#define WARP_SIZE 32
#define CUDA_CHECK(callstr) \
{ \
hipError_t error_code = callstr; \
if (error_code != hipSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
    return std::max(
        std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
        // Use at least 1 block, since CUDA does not allow empty block
        1);
}

class Context {
public:
    Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0)
    {
        hiprandCreateGenerator(&_gen, HIPRAND_RNG_PSEUDO_DEFAULT);
        hiprandSetPseudoRandomGeneratorSeed(_gen, 123);
        if (rocblas_create_handle(&_cublasHandle) != rocblas_status_success) {
            auto message = std::string("Fail to create cublas handle.");
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
        rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
        hipEventCreate(&_comp1_event, (hipEventDisableTiming | hipEventBlockingSync));
        hipEventCreate(&_comp2_event, (hipEventDisableTiming | hipEventBlockingSync));
        hipEventCreate(&_comp_event, (hipEventDisableTiming | hipEventBlockingSync));
        hipEventCreate(&_comm_event, (hipEventDisableTiming | hipEventBlockingSync));
    }

    virtual ~Context()
    {
        rocblas_destroy_handle(_cublasHandle);
        hipFree(_workspace);
        hipEventDestroy(_comp1_event);
        hipEventDestroy(_comp2_event);
        hipEventDestroy(_comp_event);
        hipEventDestroy(_comm_event);
    }

    static Context& Instance()
    {
        static Context _ctx;
        return _ctx;
    }

    void GenWorkSpace(size_t size)
    {
        if (!_workspace) {
            assert(_workspace == nullptr);
            hipMalloc(&_workspace, size);
        } else if (_workSpaceSize < size) {
            hipFree(_workspace);
            hipMalloc(&_workspace, size);
        }

        _workSpaceSize = size;
    }

    hipEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }

    size_t get_workspace_size() const { return _workSpaceSize; }
    void* GetWorkSpace() { return _workspace; }

    inline unsigned new_token(unsigned layer_id)
    {
        if (layer_id == 0) _token_length++;
        return _token_length;
    }

    inline void reset_tokens(unsigned initial_tokens = 0)
    {
        _num_tokens = initial_tokens;
    }  //_token_length = 0; }

    inline unsigned current_tokens() const { return _num_tokens; }

    inline void advance_tokens() { _num_tokens++; }

    hiprandGenerator_t& GetRandGenerator() { return _gen; }

    hipStream_t GetCommStream(bool async_op = false)
    {
        if (!_comm_stream)
            _comm_stream = async_op ? at::hip::getStreamFromPoolMasqueradingAsCUDA(true)
                                    : at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
        return _comm_stream;
    }

    hipStream_t GetCurrentStream(bool other_stream = false)
    {
        // get current pytorch stream.
        if (other_stream) {
            if (!_stream) _stream = at::hip::getStreamFromPoolMasqueradingAsCUDA(true);
            return _stream;
        }
        hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
        return stream;
    }

    rocblas_handle GetCublasHandle() { return _cublasHandle; }

    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }

    inline void SynchComp()
    {
        hipEventRecord(_comp_event, _comp_stream);
        hipStreamWaitEvent(_comm_stream, _comp_event, 0);
    }
    inline void SynchComm()
    {
        hipEventRecord(_comm_event, _comm_stream);
        hipStreamWaitEvent(_comp_stream, _comm_event, 0);
    }

private:
    hiprandGenerator_t _gen;
    rocblas_handle _cublasHandle;
    hipEvent_t _comp_event;
    hipEvent_t _comm_event;

    void* _workspace;
    uint64_t _seed;
    uint64_t _curr_offset;
    size_t _workSpaceSize;

    hipEvent_t _comp1_event;
    hipEvent_t _comp2_event;

    hipStream_t _stream;

    unsigned _token_length;
    unsigned _num_tokens;

    std::vector<std::array<int, 3>> _gemm_algos;

    hipStream_t _comp_stream;
    hipStream_t _comm_stream;

    std::unordered_map<int, int> _world_sizes;
};
csrc/transformer_bak/inference/includes/cublas_wrappers.h
0 → 100644
View file @
7d1a83a9
#pragma once
#include <assert.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>
#include <stdio.h>
int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const float* A,
                   const float* B,
                   float* C,
                   cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmEx(handle,
                                         transa,
                                         transb,
                                         m,
                                         n,
                                         k,
                                         (const void*)alpha,
                                         (const void*)A,
                                         CUDA_R_32F,
                                         (transa == CUBLAS_OP_N) ? m : k,
                                         (const void*)B,
                                         CUDA_R_32F,
                                         (transb == CUBLAS_OP_N) ? k : n,
                                         (const void*)beta,
                                         C,
                                         CUDA_R_32F,
                                         m,
                                         CUDA_R_32F,
                                         algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_gemm_ex(cublasHandle_t handle,
                   cublasOperation_t transa,
                   cublasOperation_t transb,
                   int m,
                   int n,
                   int k,
                   const float* alpha,
                   const float* beta,
                   const __half* A,
                   const __half* B,
                   __half* C,
                   cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmEx(handle,
                                         transa,
                                         transb,
                                         m,
                                         n,
                                         k,
                                         (const void*)alpha,
                                         (const void*)A,
                                         CUDA_R_16F,
                                         (transa == CUBLAS_OP_N) ? m : k,
                                         (const void*)B,
                                         CUDA_R_16F,
                                         (transb == CUBLAS_OP_N) ? k : n,
                                         (const void*)beta,
                                         (void*)C,
                                         CUDA_R_16F,
                                         m,
                                         CUDA_R_32F,
                                         algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const float* A,
                                const float* B,
                                float* C,
                                cublasOperation_t op_A,
                                cublasOperation_t op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
                                                       op_A,
                                                       op_B,
                                                       m,
                                                       n,
                                                       k,
                                                       alpha,
                                                       A,
                                                       CUDA_R_32F,
                                                       (op_A == CUBLAS_OP_N) ? m : k,
                                                       stride_A,
                                                       B,
                                                       CUDA_R_32F,
                                                       (op_B == CUBLAS_OP_N) ? k : n,
                                                       stride_B,
                                                       beta,
                                                       C,
                                                       CUDA_R_32F,
                                                       m,
                                                       stride_C,
                                                       batch,
                                                       CUDA_R_32F,
                                                       algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
                batch,
                m,
                n,
                k,
                (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}

int cublas_strided_batched_gemm(cublasHandle_t handle,
                                int m,
                                int n,
                                int k,
                                const float* alpha,
                                const float* beta,
                                const __half* A,
                                const __half* B,
                                __half* C,
                                cublasOperation_t op_A,
                                cublasOperation_t op_B,
                                int stride_A,
                                int stride_B,
                                int stride_C,
                                int batch,
                                cublasGemmAlgo_t algo)
{
    cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
                                                       op_A,
                                                       op_B,
                                                       m,
                                                       n,
                                                       k,
                                                       alpha,
                                                       A,
                                                       CUDA_R_16F,
                                                       (op_A == CUBLAS_OP_N) ? m : k,
                                                       stride_A,
                                                       B,
                                                       CUDA_R_16F,
                                                       (op_B == CUBLAS_OP_N) ? k : n,
                                                       stride_B,
                                                       beta,
                                                       C,
                                                       CUDA_R_16F,
                                                       m,
                                                       stride_C,
                                                       batch,
                                                       CUDA_R_32F,
                                                       algo);

    if (status != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr,
                "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
                m,
                n,
                k,
                (int)status);
        return EXIT_FAILURE;
    }
    return 0;
}
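Illustrative usage sketch (not part of the committed diff): a single-precision GEMM through the wrapper above, computing a column-major m x n output; the handle would typically come from Context::Instance().GetCublasHandle(), and the function name, dimensions, and algorithm choice below are example assumptions.

// Hypothetical caller of cublas_gemm_ex; dimensions and algorithm are example values.
int gemm_usage_example(cublasHandle_t handle, const float* A, const float* B, float* C)
{
    const int m = 64, n = 32, k = 128;      // C (m x n) = A (m x k) * B (k x n), column-major
    const float alpha = 1.0f, beta = 0.0f;
    return cublas_gemm_ex(handle,
                          CUBLAS_OP_N,
                          CUBLAS_OP_N,
                          m,
                          n,
                          k,
                          &alpha,
                          &beta,
                          A,
                          B,
                          C,
                          CUBLAS_GEMM_DEFAULT);
}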