Commit 7d1a83a9 authored by aiss

Push DeepSpeed 0.6.3 ROCm version

parent ab5534fc
// !!! This is a file automatically generated by hipify!!!
#include <torch/extension.h>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer_hip.h"
#include "context_hip.h"
#include "cublas_wrappers_hip.h"
#include "custom_hip_layers.h"
#include "ds_transformer_hip.h"
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;
const int init_seq_length = 128;
// C++ interface
template <typename T>
unsigned get_workspace_size(unsigned maxBatchSize,
unsigned seq_len,
unsigned hidden_size,
unsigned intermediate_size,
unsigned heads,
bool training,
bool gelu_checkpoint)
{
unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
if (training) {
workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
if (gelu_checkpoint)
workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
}
return workSpacesize; // * sizeof(T);
}
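// Worked example of the sizing above (illustrative numbers only, not part of this commit):
// maxBatchSize = 8, seq_len = 128, hidden_size = 1024, intermediate_size = 4096,
// heads = 16, training = true, gelu_checkpoint = true:
//   base            : 4 * (8 * 128 * 1024)               =  4,194,304
//   training        : + 2 * (8 * 128 * 1024)             = +2,097,152
//   attn/ff scratch : + max(8*128*4096, 2*8*16*128*128)  = +4,194,304
//   gelu checkpoint : + 2 * (8 * 128 * 4096)             = +8,388,608
//   total                                                = 18,874,368 elements (not bytes)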
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
unsigned batch_size,
unsigned hidden_size,
unsigned num_heads,
unsigned intermediate_size,
unsigned seq_length,
float attn_prob_dropout_ratio,
float hidden_output_dropout_ratio,
float layer_norm_eps,
bool pre_or_postLayerNorm,
const std::vector<std::array<int, 3>>& gemm_algos,
bool attn_dropout_checkpoint,
bool normalize_invertible,
bool gelu_checkpoint,
bool stochastic_mode)
: _layer_id(layer_id),
_batch_size(batch_size),
_hidden_size(hidden_size),
_heads(num_heads),
_intermediate_size(intermediate_size),
_seq_length(seq_length),
_training(true),
_pre_or_postLayerNorm(pre_or_postLayerNorm),
_attn_dropout_checkpoint(attn_dropout_checkpoint),
_normalize_invertible(normalize_invertible),
_gelu_checkpoint(gelu_checkpoint),
_stochastic_mode(stochastic_mode),
_stream(Context::Instance().GetCurrentStream()),
_cublasHandle(Context::Instance().GetCublasHandle()),
_qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
3 * hidden_size,
hidden_size,
gemm_algos[0])),
_attn_out_linear(typename FeedForward<T>::Config(batch_size * seq_length,
hidden_size,
hidden_size,
gemm_algos[0])),
_attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
layer_norm_eps,
true,
!normalize_invertible)),
_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
layer_norm_eps,
true,
!normalize_invertible)),
_ff1(typename FeedForward<T>::Config(batch_size * seq_length,
_intermediate_size,
hidden_size,
gemm_algos[1])),
_ff2(typename FeedForward<T>::Config(batch_size * seq_length,
hidden_size,
_intermediate_size,
gemm_algos[2])),
_softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
_gelu(typename Gelu<T>::Config(_intermediate_size)),
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
_attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
_layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
_attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
_seq_length,
_seq_length,
_hidden_size / _heads,
//aiss debug 0506
//(T(1.0) / T(sqrt(_hidden_size / _heads))),
(T(1.0 / (sqrt(_hidden_size / _heads)))),
T(0.0),
rocblas_operation_transpose,
rocblas_operation_none,
gemm_algos[3])),
_attn_context(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
_hidden_size / _heads,
_seq_length,
_seq_length,
T(1.0),
T(0.0),
rocblas_operation_none,
rocblas_operation_none,
gemm_algos[4]))
{
assert(_hidden_size % _heads == 0);
Initialize();
}
template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer()
{
}
template <typename T>
void BertTransformerLayer<T>::Initialize()
{
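// Compiled out on ROCm builds (__HIP_PLATFORM_HCC__ is defined there); on the
// CUDA path this call is intended to enable tensor-op math for the FP16
// instantiation of the layer.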
#ifndef __HIP_PLATFORM_HCC__
if (std::is_same<T, __half>::value) rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
}
template <typename T>
void BertTransformerLayer<T>::Forward(unsigned bsz,
const T* input_ptr,
const T* input_mask_ptr,
const T* attn_qkvw_ptr,
const T* attn_qkvb_ptr,
const T* attn_ow_ptr,
const T* attn_ob_ptr,
const T* attn_nw_ptr,
const T* attn_nb_ptr,
const T* inter_w_ptr,
const T* inter_b_ptr,
const T* output_w_ptr,
const T* output_b_ptr,
const T* norm_w_ptr,
const T* norm_b_ptr,
T* out_ptr,
T* inp_norm_ptr,
T* q_tf_ptr,
T* k_tf_ptr,
T* v_tf_ptr,
T* soft_out_ptr,
T* ctx_bufB_ptr,
T* attn_o_inp_ptr,
T* add_res_ptr,
T* ff1_inp_ptr,
T* gelu_inp_ptr,
T* ff2_inp_ptr)
{
rocblas_set_stream(_cublasHandle, _stream);
if (!_stochastic_mode) hipStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size;
T* buf_2 = buf_1;
if (_normalize_invertible) {
add_res_ptr = buf_1 + 3 * small_buf_size;
buf_2 = add_res_ptr;
}
if (_gelu_checkpoint) buf_2 += small_buf_size;
if (_attn_dropout_checkpoint)
ctx_bufB_ptr =
(_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size)
: (buf_1 + 4 * small_buf_size));
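// Note on the scratch layout above: buf_0/buf_1 carve fixed-size slices of the shared
// workspace (small_buf_size = bsz * seq * hidden elements each). When normalize_invertible,
// gelu_checkpoint or attn_dropout_checkpoint are enabled, the corresponding activations are
// not kept as separate PyTorch tensors but are redirected further into the workspace, to be
// re-read or recomputed during Backward.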
int bsz_seq = bsz * _seq_length;
if (_pre_or_postLayerNorm) {
if (_layer_norm.UseMean())
_layer_norm.ForwardCheckpoint(
bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
_layer_norm.Forward(
bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
if (_pre_or_postLayerNorm)
_qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
else
_qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
launch_bias_add_transform_0213<T>(
q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);
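// The transform above adds the QKV bias and reorders the fused [bsz, seq, 3 x hidden] GEMM
// output into three head-major buffers (q_tf/k_tf/v_tf point into one contiguous allocation),
// so the strided-batch GEMMs below can treat each attention head as an independent matrix.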
int bsz_heads = bsz * _heads;
// attention scores
_attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);
// Softmax + Mask
_softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);
// attn prob dropout.
_attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);
// attention context
_attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);
launch_transform4d_0213<T>(
attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);
if (_pre_or_postLayerNorm)
_attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _cublasHandle);
else
_attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _cublasHandle);
// attn output dropout.
if (_pre_or_postLayerNorm)
_attn_output_dropout.ForwardWithBias(
bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream);
else
_attn_output_dropout.ForwardWithBias(
bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);
if (_pre_or_postLayerNorm) {
if (_attn_layer_norm.UseMean())
_attn_layer_norm.ForwardCheckpoint(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
_attn_layer_norm.Forward(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
} else {
if (_attn_layer_norm.UseMean())
_attn_layer_norm.ForwardCheckpoint(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
_attn_layer_norm.Forward(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
}
_ff1.Forward(bsz_seq,
ff1_inp_ptr,
inter_w_ptr,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
_cublasHandle);
_gelu.ForwardWithBiasAdd(bsz_seq,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
inter_b_ptr,
(_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
_stream);
_ff2.Forward(
bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle);
// layer output dropout.
if (_pre_or_postLayerNorm)
_layer_output_dropout.ForwardWithBias(
bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
else
_layer_output_dropout.ForwardWithBias(
bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);
if (!_pre_or_postLayerNorm) {
if (_layer_norm.UseMean())
_layer_norm.ForwardCheckpoint(
bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
_layer_norm.Forward(
bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
}
template <typename T>
void BertTransformerLayer<T>::Backward(unsigned bsz,
const T* grad_output_ptr,
const T* input_ptr,
const T* output_ptr,
const T* inp_norm_ptr,
const T* q_tf_ptr,
const T* k_tf_ptr,
const T* v_tf_ptr,
const T* soft_out_ptr,
const T* ctx_bufB_ptr,
const T* attn_o_inp_ptr,
const T* add_res_ptr,
const T* ff1_inp_ptr,
const T* gelu_inp_ptr,
const T* ff2_inp_ptr,
const T* input_mask_ptr,
const T* attn_qkvw_ptr,
const T* attn_ow_ptr,
const T* attn_nw_ptr,
const T* attn_nb_ptr,
const T* inter_w_ptr,
const T* inter_b_ptr,
const T* output_w_ptr,
const T* norm_w_ptr,
const T* norm_b_ptr,
T* grad_input_ptr,
T* grad_attn_qkvw_ptr,
T* grad_attn_qkvb_ptr,
T* grad_attn_ow_ptr,
T* grad_attn_ob_ptr,
T* grad_attn_nw_ptr,
T* grad_attn_nb_ptr,
T* grad_inter_w_ptr,
T* grad_inter_b_ptr,
T* grad_output_w_ptr,
T* grad_output_b_ptr,
T* grad_norm_w_ptr,
T* grad_norm_b_ptr)
{
rocblas_set_stream(_cublasHandle, _stream);
if (!_stochastic_mode) hipStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size;
T* buf_2 = buf_1 + small_buf_size;
T* buf_3 = buf_2 + small_buf_size;
T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size)
: buf_3 + small_buf_size);
T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);
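// Backward reuses the same workspace: buf_0..buf_3 are bsz * seq * hidden scratch slices,
// ff2_buf is a larger scratch region used for the FF2/GELU gradient and, later, for the
// attention-score gradients, and ctx_bufB_ptr_recomp is only touched when attention-dropout
// checkpointing requires re-applying the dropout mask to soft_out.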
hipStream_t streams[2] = {_stream, _stream};
int bsz_seq = bsz * _seq_length;
int bsz_heads = bsz * _heads;
if (!_pre_or_postLayerNorm) {
if (_layer_norm.UseMean())
_layer_norm.Backward(bsz_seq,
grad_output_ptr,
norm_w_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
buf_1,
inp_norm_ptr);
else
_layer_norm.Backward(bsz_seq,
grad_output_ptr,
norm_w_ptr,
norm_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
buf_1,
output_ptr);
}
if (_pre_or_postLayerNorm)
_layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream);
else
_layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream);
const T* layer_dropout_buf = _layer_output_dropout.HasDropout()
? buf_0
: (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);
if (_gelu_checkpoint)
_gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
_ff2.Backward(bsz_seq,
layer_dropout_buf,
(_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
output_w_ptr,
grad_output_w_ptr,
grad_output_b_ptr,
_cublasHandle,
_stream,
ff2_buf);
_gelu.Backward(
bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);
_ff1.Backward(bsz_seq,
ff2_buf,
ff1_inp_ptr,
inter_w_ptr,
grad_inter_w_ptr,
grad_inter_b_ptr,
_cublasHandle,
_stream,
buf_3);
if (!_pre_or_postLayerNorm)
launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);
if (_pre_or_postLayerNorm) {
if (_attn_layer_norm.UseMean())
_attn_layer_norm.BackwardFusedAdd(bsz_seq,
buf_3,
grad_output_ptr,
attn_nw_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
add_res_ptr);
else
_attn_layer_norm.BackwardFusedAdd(bsz_seq,
buf_3,
grad_output_ptr,
attn_nw_ptr,
attn_nb_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
ff1_inp_ptr);
} else {
if (_attn_layer_norm.UseMean())
_attn_layer_norm.Backward(bsz_seq,
buf_2,
attn_nw_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
add_res_ptr);
else
_attn_layer_norm.Backward(bsz_seq,
buf_2,
attn_nw_ptr,
attn_nb_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
ff1_inp_ptr);
}
_attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);
T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;
_attn_out_linear.Backward(bsz_seq,
attn_output_dropout_buf,
attn_o_inp_ptr,
attn_ow_ptr,
grad_attn_ow_ptr,
grad_attn_ob_ptr,
_cublasHandle,
_stream,
buf_1);
launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);
if (_attn_prob_dropout.HasDropout()) {
if (_attn_dropout_checkpoint)
_attn_prob_dropout.Forward(
bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);
_attn_context.Backward(bsz_heads,
buf_2,
v_tf_ptr,
(_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr),
_cublasHandle,
buf_3,
ff2_buf);
} else
_attn_context.Backward(
bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf);
_attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);
_softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);
_attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);
launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);
if (_pre_or_postLayerNorm)
_qkv_linear.Backward(bsz_seq,
ff2_buf,
inp_norm_ptr,
attn_qkvw_ptr,
grad_attn_qkvw_ptr,
grad_attn_qkvb_ptr,
_cublasHandle,
_stream,
buf_2);
else
_qkv_linear.Backward(bsz_seq,
ff2_buf,
input_ptr,
attn_qkvw_ptr,
grad_attn_qkvw_ptr,
grad_attn_qkvb_ptr,
_cublasHandle,
_stream,
buf_2);
if (_pre_or_postLayerNorm) {
if (_layer_norm.UseMean())
_layer_norm.BackwardFusedAdd(bsz_seq,
buf_2,
buf_0,
norm_w_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
grad_input_ptr,
input_ptr);
else
_layer_norm.BackwardFusedAdd(bsz_seq,
buf_2,
buf_0,
norm_w_ptr,
norm_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
grad_input_ptr,
inp_norm_ptr);
} else
launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}
template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
// Dropout will be skipped when not in training mode.
_attn_prob_dropout.SetTrainingMode(training);
_attn_output_dropout.SetTrainingMode(training);
_layer_output_dropout.SetTrainingMode(training);
}
template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
uint8_t* attn_output_dropout_mask_ptr,
uint8_t* layer_output_dropout_mask_ptr,
T* attn_layer_norm_var,
T* attn_layer_norm_mean,
T* layer_norm_var,
T* layer_norm_mean)
{
_attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
_attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
_layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);
_attn_layer_norm.SetVar(attn_layer_norm_var);
_attn_layer_norm.SetMean(attn_layer_norm_mean);
_layer_norm.SetVar(layer_norm_var);
_layer_norm.SetMean(layer_norm_mean);
}
template <typename T>
void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
{
_seq_length = seq_len;
_softmax.SetSeqLength(_seq_length);
_attn_prob_dropout.SetDimension(_seq_length);
_attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
_attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
}
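// SetSeqLength is called from the forward/backward bindings whenever the incoming batch's
// sequence length differs from the cached one, so the softmax, attention-dropout and
// strided-batch GEMM configs track the new shape without rebuilding the layer.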
template <typename T>
int create_transformer_layer(unsigned layer_id,
unsigned batch_size,
unsigned hidden_dim,
unsigned num_heads,
unsigned intermediate_size,
float attn_dropout_ratio,
float hidden_dropout_ratio,
float layer_norm_eps,
int seed,
bool pre_or_postLayerNorm,
bool test_gemm,
bool attn_dropout_checkpoint,
bool normalize_invertible,
bool gelu_checkpoint,
bool stochastic_mode)
{
Context::Instance().SetSeed(seed);
Context::Instance().TestGemmFP16(
test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);
auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
batch_size,
hidden_dim,
num_heads,
intermediate_size,
init_seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
layer_norm_eps,
pre_or_postLayerNorm,
Context::Instance().GetGemmAlgos(),
attn_dropout_checkpoint,
normalize_invertible,
gelu_checkpoint,
stochastic_mode);
s_transformer_layers[layer_id] = layer;
std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
std::cout << "layer #" << layer_id << " is created with data type [" << dtype << "]."
<< std::endl;
return 0;
}
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
const torch::Tensor& input,
const torch::Tensor& input_mask,
const torch::Tensor& attn_qkvw,
const torch::Tensor& attn_qkvb,
const torch::Tensor& attn_ow,
const torch::Tensor& attn_ob,
const torch::Tensor& attn_nw,
const torch::Tensor& attn_nb,
const torch::Tensor& inter_w,
const torch::Tensor& inter_b,
const torch::Tensor& output_w,
const torch::Tensor& output_b,
const torch::Tensor& norm_w,
const torch::Tensor& norm_b,
bool training_mode,
bool prelayernorm,
bool attn_dropout_checkpoint,
bool normalize_invertible,
bool gelu_checkpoint)
{
CHECK_INPUT(input);
CHECK_INPUT(input_mask);
CHECK_INPUT(attn_qkvw);
CHECK_INPUT(attn_qkvb);
CHECK_INPUT(attn_ow);
CHECK_INPUT(attn_ob);
CHECK_INPUT(attn_nw);
CHECK_INPUT(attn_nb);
CHECK_INPUT(inter_w);
CHECK_INPUT(inter_b);
CHECK_INPUT(output_w);
CHECK_INPUT(output_b);
CHECK_INPUT(norm_w);
CHECK_INPUT(norm_b);
unsigned bsz = input.size(0);
const T* input_ptr = (const T*)input.data_ptr();
const T* input_mask_ptr = (const T*)input_mask.data_ptr();
const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
const T* inter_w_ptr = (const T*)inter_w.data_ptr();
const T* inter_b_ptr = (const T*)inter_b.data_ptr();
const T* output_w_ptr = (const T*)output_w.data_ptr();
const T* output_b_ptr = (const T*)output_b.data_ptr();
const T* norm_w_ptr = (const T*)norm_w.data_ptr();
const T* norm_b_ptr = (const T*)norm_b.data_ptr();
auto output = torch::empty_like(input);
T* out_ptr = (T*)output.data_ptr();
auto options = torch::TensorOptions()
.dtype(input.options().dtype())
.layout(torch::kStrided)
.device(torch::kCUDA)
.requires_grad(true);
auto uint8_options = torch::TensorOptions()
.dtype(torch::kInt8)
.layout(torch::kStrided)
.device(torch::kCUDA)
.requires_grad(false);
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
unsigned seq_len = layer->GetSeqLength();
if (input.size(1) != seq_len) {
seq_len = input.size(1);
layer->SetSeqLength(seq_len);
}
auto workspace = torch::empty({get_workspace_size<T>(bsz,
seq_len,
layer->GetHiddenSize(),
layer->GetIntermediateSize(),
layer->GetNumHeads(),
layer->IsTrainingMode(),
layer->GeluCheckpoint())},
options);
Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
auto attn_o_inp = torch::empty_like(input);
auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);
auto attn_prob_dropout_mask =
torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
auto attn_output_dropout_mask =
torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
auto layer_output_dropout_mask =
torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
T* inp_norm_ptr = (T*)inp_norm.data_ptr();
T* add_res_ptr = (T*)add_res.data_ptr();
T* q_tf_ptr = (T*)qkv_tf.data_ptr();
T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)k_tf.data_ptr();
T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)v_tf.data_ptr();
T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();
torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
torch::Tensor gelu_inp =
(gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
auto ff1_inp = torch::empty_like(input);
T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();
torch::Tensor soft_out =
torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
torch::Tensor ctx_bufB =
(attn_dropout_checkpoint
? soft_out
: torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
T* soft_out_ptr = (T*)soft_out.data_ptr();
T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();
layer->SetTrainingMode(training_mode);
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
(uint8_t*)layer_output_dropout_mask.data_ptr(),
(T*)attn_layer_norm_var.data_ptr(),
(T*)attn_layer_norm_mean.data_ptr(),
(T*)layer_norm_var.data_ptr(),
(T*)layer_norm_mean.data_ptr());
layer->Forward(bsz,
input_ptr,
input_mask_ptr,
attn_qkvw_ptr,
attn_qkvb_ptr,
attn_ow_ptr,
attn_ob_ptr,
attn_nw_ptr,
attn_nb_ptr,
inter_w_ptr,
inter_b_ptr,
output_w_ptr,
output_b_ptr,
norm_w_ptr,
norm_b_ptr,
out_ptr,
inp_norm_ptr,
q_tf_ptr,
k_tf_ptr,
v_tf_ptr,
soft_out_ptr,
ctx_bufB_ptr,
attn_o_inp_ptr,
add_res_ptr,
ff1_inp_ptr,
gelu_inp_ptr,
ff2_inp_ptr);
return {output,
inp_norm,
qkv_tf,
soft_out,
ctx_bufB,
attn_o_inp,
add_res,
ff1_inp,
gelu_inp,
ff2_inp,
attn_prob_dropout_mask,
attn_output_dropout_mask,
layer_output_dropout_mask,
attn_layer_norm_var,
attn_layer_norm_mean,
layer_norm_var,
layer_norm_mean};
}
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
const torch::Tensor& grad_output,
const torch::Tensor& output,
const torch::Tensor& inp_norm,
const torch::Tensor& qkv_tf,
const torch::Tensor& soft_out,
const torch::Tensor& ctx_bufB,
const torch::Tensor& attn_o_inp,
const torch::Tensor& add_res,
const torch::Tensor& ff1_inp,
const torch::Tensor& gelu_inp,
const torch::Tensor& ff2_inp,
const torch::Tensor& attn_prob_dropout_mask,
const torch::Tensor& attn_output_dropout_mask,
const torch::Tensor& layer_output_dropout_mask,
const torch::Tensor& attn_layer_norm_var,
const torch::Tensor& attn_layer_norm_mean,
const torch::Tensor& layer_norm_var,
const torch::Tensor& layer_norm_mean,
const torch::Tensor& input,
const torch::Tensor& input_mask,
const torch::Tensor& attn_qkvw,
const torch::Tensor& attn_qkvb,
const torch::Tensor& attn_ow,
const torch::Tensor& attn_ob,
const torch::Tensor& attn_nw,
const torch::Tensor& attn_nb,
const torch::Tensor& inter_w,
const torch::Tensor& inter_b,
const torch::Tensor& output_w,
const torch::Tensor& output_b,
const torch::Tensor& norm_w,
const torch::Tensor& norm_b)
{
auto g_output = grad_output.contiguous();
CHECK_INPUT(g_output);
CHECK_INPUT(output);
CHECK_INPUT(inp_norm);
CHECK_INPUT(qkv_tf);
CHECK_INPUT(add_res);
CHECK_INPUT(soft_out);
CHECK_INPUT(ctx_bufB);
CHECK_INPUT(attn_o_inp);
CHECK_INPUT(ff1_inp);
CHECK_INPUT(gelu_inp);
CHECK_INPUT(ff2_inp);
CHECK_INPUT(input);
CHECK_INPUT(input_mask);
CHECK_INPUT(attn_qkvw);
CHECK_INPUT(attn_qkvb);
CHECK_INPUT(attn_ow);
CHECK_INPUT(attn_ob);
CHECK_INPUT(attn_nw);
CHECK_INPUT(attn_nb);
CHECK_INPUT(inter_w);
CHECK_INPUT(inter_b);
CHECK_INPUT(output_w);
CHECK_INPUT(output_b);
CHECK_INPUT(norm_w);
CHECK_INPUT(norm_b);
unsigned bsz = g_output.size(0);
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
unsigned seq_len = layer->GetSeqLength();
if (g_output.size(1) != seq_len) {
seq_len = g_output.size(1);
layer->SetSeqLength(seq_len);
}
auto options = torch::TensorOptions()
.dtype(g_output.options().dtype())
.layout(torch::kStrided)
.device(torch::kCUDA)
.requires_grad(true);
auto workspace = torch::empty({get_workspace_size<T>(bsz,
seq_len,
layer->GetHiddenSize(),
layer->GetIntermediateSize(),
layer->GetNumHeads(),
layer->IsTrainingMode(),
layer->GeluCheckpoint())},
options);
Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
auto grad_input = torch::empty_like(input);
auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
auto grad_attn_ow = torch::empty_like(attn_ow);
auto grad_attn_ob = torch::empty_like(attn_ob);
auto grad_attn_nw = torch::empty_like(attn_nw);
auto grad_attn_nb = torch::empty_like(attn_nb);
auto grad_inter_w = torch::empty_like(inter_w);
auto grad_inter_b = torch::empty_like(inter_b);
auto grad_output_w = torch::empty_like(output_w);
auto grad_output_b = torch::empty_like(output_b);
auto grad_norm_w = torch::empty_like(norm_w);
auto grad_norm_b = torch::empty_like(norm_b);
// inputs.
const T* grad_output_ptr = (const T*)g_output.data_ptr();
const T* input_ptr = (const T*)input.data_ptr();
const T* output_ptr = (const T*)output.data_ptr();
const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
const T* add_res_ptr = (const T*)add_res.data_ptr();
const T* k_tf_ptr =
q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(const T*)k_tf.data_ptr();
const T* v_tf_ptr =
k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(const T*)v_tf.data_ptr();
const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
const T* soft_out_ptr = (const T*)soft_out.data_ptr();
const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
const T* input_mask_ptr = (const T*)input_mask.data_ptr();
const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
const T* inter_w_ptr = (const T*)inter_w.data_ptr();
const T* inter_b_ptr = (const T*)inter_b.data_ptr();
const T* output_w_ptr = (const T*)output_w.data_ptr();
const T* norm_w_ptr = (const T*)norm_w.data_ptr();
const T* norm_b_ptr = (const T*)norm_b.data_ptr();
// outputs.
T* grad_input_ptr = (T*)grad_input.data_ptr();
T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
(uint8_t*)layer_output_dropout_mask.data_ptr(),
(T*)attn_layer_norm_var.data_ptr(),
(T*)attn_layer_norm_mean.data_ptr(),
(T*)layer_norm_var.data_ptr(),
(T*)layer_norm_mean.data_ptr());
layer->Backward(bsz,
grad_output_ptr,
input_ptr,
output_ptr,
inp_norm_ptr,
q_tf_ptr,
k_tf_ptr,
v_tf_ptr,
soft_out_ptr,
ctx_bufB_ptr,
attn_o_inp_ptr,
add_res_ptr,
ff1_inp_ptr,
gelu_inp_ptr,
ff2_inp_ptr,
input_mask_ptr,
attn_qkvw_ptr,
attn_ow_ptr,
attn_nw_ptr,
attn_nb_ptr,
inter_w_ptr,
inter_b_ptr,
output_w_ptr,
norm_w_ptr,
norm_b_ptr,
grad_input_ptr,
grad_attn_qkvw_ptr,
grad_attn_qkvb_ptr,
grad_attn_ow_ptr,
grad_attn_ob_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
grad_inter_w_ptr,
grad_inter_b_ptr,
grad_output_w_ptr,
grad_output_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr);
return {grad_input,
grad_attn_qkvw,
grad_attn_qkvb,
grad_attn_ow,
grad_attn_ob,
grad_attn_nw,
grad_attn_nb,
grad_inter_w,
grad_inter_b,
grad_output_w,
grad_output_b,
grad_norm_w,
grad_norm_b};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward_fp32",
&ds_transformer_forward<float>,
"DeepSpeed Transformer forward with fp32 (CUDA)");
m.def("forward_fp16",
&ds_transformer_forward<__half>,
"DeepSpeed Transformer forward with fp16 (CUDA)");
m.def("backward_fp32",
&ds_transformer_backward<float>,
"DeepSpeed Transformer backward with fp32 (CUDA)");
m.def("backward_fp16",
&ds_transformer_backward<__half>,
"DeepSpeed Transformer backward with fp16 (CUDA)");
m.def("create_transformer_layer_fp32",
&create_transformer_layer<float>,
"Create DeepSpeed Transformer Transformer Layer with fp32 (CUDA)");
m.def("create_transformer_layer_fp16",
&create_transformer_layer<__half>,
"Create DeepSpeed Transformer Transformer Layer with fp16 (CUDA)");
}
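// These bindings are consumed from Python: the DeepSpeed transformer op loads this extension,
// calls create_transformer_layer_fp16/fp32 once per layer, and wraps forward_/backward_ in an
// autograd function that stashes the returned activation tensors for the backward call.
// (Described here for context; the Python wrapper is not part of this file.)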
#include "custom_cuda_layers.h"
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
inline __device__ float d_gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
float x2mul = x * x * mul_param;
float tan_h = tanhf(sqrt_param * (x + x * x2mul));
float dg1 = 0.5f * (1.0f + tan_h);
float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
float dg3 = dg2 * 3 * x2mul;
return (dg1 + dg2 + dg3);
}
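// Derivation check for d_gelu (not in the original source): with
//   u = sqrt(2/pi) * (x + 0.044715 * x^3),   gelu(x) = 0.5 * x * (1 + tanh(u)),
// the chain rule gives
//   gelu'(x) = 0.5 * (1 + tanh(u))                                        -> dg1
//            + 0.5 * x * (1 - tanh(u)^2) * sqrt(2/pi)                     -> dg2
//            + 0.5 * x * (1 - tanh(u)^2) * sqrt(2/pi) * 3 * 0.044715*x^2  -> dg3
// which is exactly dg1 + dg2 + dg3 as computed above.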
/*
Fused bias add with GELU

Each thread loads a vector of 4 elements per iteration, striding by the block
size, for `iterations` iterations. The kernels were written with the intention
of launching 256-thread thread blocks, so for bert-large we would set
ITERATIONS to 4. This is now done automatically as a heuristic, setting the
number of iterations from blocks of 1024 elements.

For FP16, the values are loaded from memory as __half but converted to FP32
for the arithmetic itself, to prevent numerical overflow on the intermediate
hyperbolic tangent, since there is no intrinsic that computes it directly.
*/
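// Example of the heuristic (illustrative, for bert-large's intermediate size of 4096):
//   iterations = (4096 + 1023) / 1024 = 4
//   threads    = (4096 - 1) / (iterations * 4) + 1 = 256
// so each of the 256 threads loads one float4 (4 elements) per iteration, covering
// 256 * 4 * 4 = 4096 elements per row; row_stride is passed to the kernels as
// intermediate_size / 4 (the row length in float4 units).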
__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void fused_bias_gelu(const float* input,
const float* bias,
float* vals,
int row_stride,
int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void fused_bias_gelu(const __half* input,
const __half* bias,
__half* vals,
int row_stride,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void d_gelu_func(float* d_output,
const float* gelu_input,
const float* bias,
int row_stride,
int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
float4* d_output_cast = reinterpret_cast<float4*>(d_output);
const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
gelu_input_data.x += bias_data.x;
gelu_input_data.y += bias_data.y;
gelu_input_data.z += bias_data.z;
gelu_input_data.w += bias_data.w;
output_data.x *= d_gelu(gelu_input_data.x);
output_data.y *= d_gelu(gelu_input_data.y);
output_data.z *= d_gelu(gelu_input_data.z);
output_data.w *= d_gelu(gelu_input_data.w);
d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
}
}
}
__global__ void d_gelu_func(__half* d_output,
const __half* gelu_input,
const __half* bias,
int row_stride,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
float2* d_output_cast = reinterpret_cast<float2*>(d_output);
const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
#pragma unroll
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
__half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 output_half_0 = __half22float2(output_data_half[0]);
float2 output_half_1 = __half22float2(output_data_half[1]);
float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);
float2 bias_half_0 = __half22float2(bias_half[0]);
float2 bias_half_1 = __half22float2(bias_half[1]);
gelu_input_half_0.x += bias_half_0.x;
gelu_input_half_0.y += bias_half_0.y;
gelu_input_half_1.x += bias_half_1.x;
gelu_input_half_1.y += bias_half_1.y;
output_half_0.x *= d_gelu(gelu_input_half_0.x);
output_half_0.y *= d_gelu(gelu_input_half_0.y);
output_half_1.x *= d_gelu(gelu_input_half_1.x);
output_half_1.y *= d_gelu(gelu_input_half_1.y);
float2 result;
__half2* result_half2 = reinterpret_cast<__half2*>(&result);
result_half2[0] = __float22half2_rn(output_half_0);
result_half2[1] = __float22half2_rn(output_half_1);
d_output_cast[row * row_stride + i * loop_stride + id] = result;
}
}
#endif
}
template <typename T>
void launch_bias_gelu(const T* input,
const T* bias,
T* output,
int intermediate_size,
int batch_size,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(
input, bias, output, intermediate_size / 4, iterations);
}
template <typename T>
void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(
input, output, intermediate_size / 4, iterations);
}
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(const __half*,
const __half*,
__half*,
int,
int,
cudaStream_t);
template void launch_gelu<float>(const float*, float*, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
const T* input,
const T* bias,
int intermediate_size,
int batch_size,
cudaStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
d_gelu_func<<<grid_dims, block_dims, 0, stream>>>(
d_output, input, bias, intermediate_size / 4, iterations);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, cudaStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t);
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
inline __device__ float d_gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
float x2mul = x * x * mul_param;
float tan_h = tanhf(sqrt_param * (x + x * x2mul));
float dg1 = 0.5f * (1.0f + tan_h);
float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
float dg3 = dg2 * 3 * x2mul;
return (dg1 + dg2 + dg3);
}
/*
Fused bias add with GELU

Each thread loads a vector of 4 elements per iteration, striding by the block
size, for `iterations` iterations. The kernels were written with the intention
of launching 256-thread thread blocks, so for bert-large we would set
ITERATIONS to 4. This is now done automatically as a heuristic, setting the
number of iterations from blocks of 1024 elements.

For FP16, the values are loaded from memory as __half but converted to FP32
for the arithmetic itself, to prevent numerical overflow on the intermediate
hyperbolic tangent, since there is no intrinsic that computes it directly.
*/
__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void fused_bias_gelu(const float* input,
const float* bias,
float* vals,
int row_stride,
int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void fused_bias_gelu(const __half* input,
const __half* bias,
__half* vals,
int row_stride,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void d_gelu_func(float* d_output,
const float* gelu_input,
const float* bias,
int row_stride,
int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
float4* d_output_cast = reinterpret_cast<float4*>(d_output);
const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
gelu_input_data.x += bias_data.x;
gelu_input_data.y += bias_data.y;
gelu_input_data.z += bias_data.z;
gelu_input_data.w += bias_data.w;
output_data.x *= d_gelu(gelu_input_data.x);
output_data.y *= d_gelu(gelu_input_data.y);
output_data.z *= d_gelu(gelu_input_data.z);
output_data.w *= d_gelu(gelu_input_data.w);
d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
}
}
}
__global__ void d_gelu_func(__half* d_output,
const __half* gelu_input,
const __half* bias,
int row_stride,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
float2* d_output_cast = reinterpret_cast<float2*>(d_output);
const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
#pragma unroll
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
__half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 output_half_0 = __half22float2(output_data_half[0]);
float2 output_half_1 = __half22float2(output_data_half[1]);
float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);
float2 bias_half_0 = __half22float2(bias_half[0]);
float2 bias_half_1 = __half22float2(bias_half[1]);
gelu_input_half_0.x += bias_half_0.x;
gelu_input_half_0.y += bias_half_0.y;
gelu_input_half_1.x += bias_half_1.x;
gelu_input_half_1.y += bias_half_1.y;
output_half_0.x *= d_gelu(gelu_input_half_0.x);
output_half_0.y *= d_gelu(gelu_input_half_0.y);
output_half_1.x *= d_gelu(gelu_input_half_1.x);
output_half_1.y *= d_gelu(gelu_input_half_1.y);
float2 result;
__half2* result_half2 = reinterpret_cast<__half2*>(&result);
result_half2[0] = __float22half2_rn(output_half_0);
result_half2[1] = __float22half2_rn(output_half_1);
d_output_cast[row * row_stride + i * loop_stride + id] = result;
}
}
#endif
}
template <typename T>
void launch_bias_gelu(const T* input,
const T* bias,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, output, intermediate_size / 4, iterations);
}
template <typename T>
void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( gelu_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
input, output, intermediate_size / 4, iterations);
}
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(const __half*,
const __half*,
__half*,
int,
int,
hipStream_t);
template void launch_gelu<float>(const float*, float*, int, int, hipStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, hipStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
const T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( d_gelu_func), dim3(grid_dims), dim3(block_dims), 0, stream,
d_output, input, bias, intermediate_size / 4, iterations);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, hipStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, hipStream_t);
#include "general_kernels.h"
namespace cg = cooperative_groups;
template <typename T>
__global__ void column_sum_reduce(const T* __restrict__ inp,
T* __restrict__ out,
int rows,
int width)
{
__shared__ float tile[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int y_stride = width * TILE_DIM;
float localSum = 0;
// Loop across matrix height
if (idx < width) {
int offset = threadIdx.y * width + idx;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
}
}
tile[threadIdx.x][threadIdx.y] = localSum;
__syncthreads();
// Sum the shared buffer.
float sum = tile[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
if (pos < width) out[pos] = sum;
}
}
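// Reduction pattern above: each thread walks down its column accumulating a partial sum in
// FP32, the TILE_DIM x TILE_DIM shared tile is written transposed (tile[threadIdx.x][threadIdx.y])
// so that a tiled partition can finish each column with shfl_down, and threadIdx.x == 0 writes
// the final column sum. This is the column-wise sum the feed-forward backward passes use to
// produce bias gradients.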
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
T* out,
int rows,
int cols,
cudaStream_t stream);
template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
float* out,
int rows,
int cols,
cudaStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<float><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
__half* out,
int rows,
int cols,
cudaStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
column_sum_reduce<__half><<<grid_dim, block_dim, 0, stream>>>(inp, out, rows, cols);
}
__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
float4* out_4 = reinterpret_cast<float4*>(out);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 val;
float4 inp1_reg = inp1_4[j];
float4 inp2_reg = inp2_4[j];
val.x = inp1_reg.x + inp2_reg.x;
val.y = inp1_reg.y + inp2_reg.y;
val.z = inp1_reg.z + inp2_reg.z;
val.w = inp1_reg.w + inp2_reg.w;
out_4[j] = val;
}
}
__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
float2 inp1_4;
float2 inp2_4;
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
CUDA_1D_KERNEL_LOOP(j, N)
{
inp1_4 = inp1_arr[j];
inp2_4 = inp2_arr[j];
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
inp1_h_f_0.x += inp2_h_f_0.x;
inp1_h_f_0.y += inp2_h_f_0.y;
inp1_h_f_1.x += inp2_h_f_1.x;
inp1_h_f_1.y += inp2_h_f_1.y;
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[j] = val_f;
}
}
template <>
void launch_fused_add2<float>(float* out,
const float* inp1,
const float* inp2,
int batch_size,
int seq_length,
int hidden_dim,
cudaStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}
template <>
void launch_fused_add2<__half>(__half* out,
const __half* inp1,
const __half* inp2,
int batch_size,
int seq_length,
int hidden_dim,
cudaStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}
__global__ void fused_add3_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add3_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add3<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int batch_size,
int seq_length,
int hidden_size,
cudaStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add3<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int batch_size,
int seq_length,
int hidden_size,
cudaStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
fused_add3_kernel<<<grid_dim, block_dim, 0, stream>>>(
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
__global__ void fused_add4_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
float4 inp4_reg = inp4_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add4_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
float2 inp4_4 = inp4_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
__half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
float2 inp4_h_f_1 = __half22float2(inp4_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add4<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int batch_size,
int seq_length,
int hidden_size,
cudaStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add4<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int batch_size,
int seq_length,
int hidden_size,
cudaStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
fused_add4_kernel<<<grid_dim, block_dim, 0, stream>>>(
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "general_kernels_hip.h"
namespace cg = cooperative_groups;
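// column_sum_reduce: sums each column of a (rows x width) matrix into out[width].
// Partial sums are staged in a TILE_DIM x TILE_DIM shared-memory tile so the final
// reduction can be finished with warp shuffles; used by the fused transpose+bias
// launcher below.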
template <typename T>
__global__ void column_sum_reduce(const T* __restrict__ inp,
T* __restrict__ out,
int rows,
int width)
{
__shared__ float tile[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int y_stride = width * TILE_DIM;
float localSum = 0;
// Loop across matrix height
if (idx < width) {
int offset = threadIdx.y * width + idx;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
}
}
tile[threadIdx.x][threadIdx.y] = localSum;
__syncthreads();
// Sum the shared buffer.
float sum = tile[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
if (pos < width) out[pos] = sum;
}
}
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
T* out,
int rows,
int cols,
hipStream_t stream);
template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
float* out,
int rows,
int cols,
hipStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( column_sum_reduce<float>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
}
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
__half* out,
int rows,
int cols,
hipStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( column_sum_reduce<__half>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
}
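// fused_add2: out = inp1 + inp2 over N vectorized elements using a grid-stride loop
// (CUDA_1D_KERNEL_LOOP); launch dimensions come from DS_GET_BLOCKS / DS_CUDA_NUM_THREADS.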
__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
float4* out_4 = reinterpret_cast<float4*>(out);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 val;
float4 inp1_reg = inp1_4[j];
float4 inp2_reg = inp2_4[j];
val.x = inp1_reg.x + inp2_reg.x;
val.y = inp1_reg.y + inp2_reg.y;
val.z = inp1_reg.z + inp2_reg.z;
val.w = inp1_reg.w + inp2_reg.w;
out_4[j] = val;
}
}
__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
float2 inp1_4;
float2 inp2_4;
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
CUDA_1D_KERNEL_LOOP(j, N)
{
inp1_4 = inp1_arr[j];
inp2_4 = inp2_arr[j];
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
inp1_h_f_0.x += inp2_h_f_0.x;
inp1_h_f_0.y += inp2_h_f_0.y;
inp1_h_f_1.x += inp2_h_f_1.x;
inp1_h_f_1.y += inp2_h_f_1.y;
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[j] = val_f;
}
}
template <>
void launch_fused_add2<float>(float* out,
const float* inp1,
const float* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
template <>
void launch_fused_add2<__half>(__half* out,
const __half* inp1,
const __half* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
__global__ void fused_add3_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add3_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add3<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add3<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
__global__ void fused_add4_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
float4 inp4_reg = inp4_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add4_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
float2 inp4_4 = inp4_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
__half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
float2 inp4_h_f_1 = __half22float2(inp4_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add4<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add4<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
#include "custom_cuda_layers.h"
//#include <cuda_profiler_api.h>
namespace cg = cooperative_groups;
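// Rotary position embedding kernels: one warp handles one attention head.
// apply_rotary_pos_emb rotates interleaved (even, odd) element pairs and is selected
// when rotate_every_two is set; apply_rotary_pos_emb1 rotates the first and second
// halves of rotary_dim and is selected when rotate_half is set (see the launcher below).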
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
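// rotate_half variant: lane i exchanges its value with lane i +/- rotary_dim/2
// (first half against second half) instead of its interleaved neighbour.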
__global__ void apply_rotary_pos_emb1(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb1(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
constexpr unsigned mask[32] = {
0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000,
0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000, 0x100 | 0x100000, 0x200 | 0x200000,
0x400 | 0x400000, 0x800 | 0x800000, 0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4,
0x8000 | 0x8, 0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800, 0x1000000,
0x2000000, 0x4000000, 0x8000000, 0x10000000, 0x20000000,
0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned half_dim = rotary_dim >> 1;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], q_rot, lane + half_dim)
: __shfl_sync(mask[lane], q_rot, lane - half_dim);
auto k_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], k_rot, lane + half_dim)
: __shfl_sync(mask[lane], k_rot, lane - half_dim);
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
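// Launcher: one warp (32 lanes) per attention head, MAX_WARP_NUM warps per block;
// dispatches to the interleaved-pair kernel for rotate_every_two or to the
// half-rotation kernel for rotate_half.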
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
bool rotate_half,
bool rotate_every_two,
cudaStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
if (rotate_every_two)
apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
else if (rotate_half)
apply_rotary_pos_emb1<<<grid_dims, block_dims, 0, stream>>>(
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
cudaStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
cudaStream_t);
/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
constexpr unsigned mask[32] = {0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000,
0x10 | 0x10000, 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000,
0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000,
0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, 0x8000 | 0x8,
0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800,
0x1000000, 0x2000000, 0x4000000, 0x8000000,
0x10000000, 0x20000000, 0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
//float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
float inv_freq = (float)((lane % (rotary_dim >> 1)) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > 11 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], q_rot, lane + 12)
                           : __shfl_sync(mask[lane], q_rot, lane - 12); // g.shfl_xor(q_rot, 12);
auto k_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], k_rot, lane + 12)
                           : __shfl_sync(mask[lane], k_rot, lane - 12); // g.shfl_xor(k_rot, 12);
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
cudaStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
*/
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
//#include <cuda_profiler_api.h>
namespace cg = cooperative_groups;
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
__global__ void apply_rotary_pos_emb1(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb1(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
constexpr unsigned mask[32] = {
0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000,
0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000, 0x100 | 0x100000, 0x200 | 0x200000,
0x400 | 0x400000, 0x800 | 0x800000, 0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4,
0x8000 | 0x8, 0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800, 0x1000000,
0x2000000, 0x4000000, 0x8000000, 0x10000000, 0x20000000,
0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned half_dim = rotary_dim >> 1;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], q_rot, lane + half_dim)
: __shfl_sync(mask[lane], q_rot, lane - half_dim);
auto k_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], k_rot, lane + half_dim)
: __shfl_sync(mask[lane], k_rot, lane - half_dim);
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
bool rotate_half,
bool rotate_every_two,
hipStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
if (rotate_every_two)
hipLaunchKernelGGL(( apply_rotary_pos_emb), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
else if (rotate_half)
hipLaunchKernelGGL(( apply_rotary_pos_emb1), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
hipStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
hipStream_t);
/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
constexpr unsigned mask[32] = {0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000,
0x10 | 0x10000, 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000,
0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000,
0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, 0x8000 | 0x8,
0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800,
0x1000000, 0x2000000, 0x4000000, 0x8000000,
0x10000000, 0x20000000, 0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
//float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
float inv_freq = (float)((lane % (rotary_dim >> 1)) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > 11 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], q_rot, lane + 12)
                           : __shfl_sync(mask[lane], q_rot, lane - 12); // g.shfl_xor(q_rot, 12);
auto k_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], k_rot, lane + 12)
                           : __shfl_sync(mask[lane], k_rot, lane - 12); // g.shfl_xor(k_rot, 12);
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
hipStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
hipLaunchKernelGGL((
apply_rotary_pos_emb), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
*/
#include "custom_cuda_layers.h"
#define MAX_QUANTIZE_GROUPING 1024
#define loop_unroll 1
#define loop_unroll_bits 1
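// dequantize_kernel: converts int8 weights back to float/__half by multiplying each
// element with its group-wise scale from qscale; merge_count shifts the hidden
// dimension for weights that were merged before quantization.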
__global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count)
{
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = (scale_data * (float)q);
tid += blockDim.x;
}
}
__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count)
{
#ifdef HALF_PRECISION_AVAILABLE
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = __float2half(scale_data * (float)q);
tid += blockDim.x;
}
#endif
}
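// Launcher: one block per hidden_dim index; 1024 threads per block stride over the
// output_size elements of that row.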
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count,
cudaStream_t stream)
{
unsigned threads = 1024;
dim3 block_dims(threads);
dim3 grid_dims(hidden_dim);
dequantize_kernel<<<grid_dims, block_dims, 0, stream>>>(
output, input, qscale, output_size, hidden_dim, groups, merge_count);
}
template void launch_dequantize<float>(float*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
#define MAX_QUANTIZE_GROUPING 1024
#define loop_unroll 1
#define loop_unroll_bits 1
__global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count)
{
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = (scale_data * (float)q);
tid += blockDim.x;
}
}
__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count)
{
#ifdef HALF_PRECISION_AVAILABLE
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = __float2half(scale_data * (float)q);
tid += blockDim.x;
}
#endif
}
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count,
hipStream_t stream)
{
unsigned threads = 1024;
dim3 block_dims(threads);
dim3 grid_dims(hidden_dim);
hipLaunchKernelGGL(( dequantize_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
output, input, qscale, output_size, hidden_dim, groups, merge_count);
}
template void launch_dequantize<float>(float*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
#include "custom_cuda_layers.h"
#define MAX_CAP 4
#define MAX_SEQ 2048
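// gelu: tanh approximation of the GELU activation,
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).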
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
__global__ void fused_bias_gelu(float* input,
const float* bias,
int total_count,
int intermediate_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
input_cast[offset] = data;
}
}
__global__ void fused_bias_gelu(__half* input,
const __half* bias,
int total_count,
int intermediate_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
input_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_gelu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
cudaStream_t stream)
{
int total_count = batch_size * (intermediate_size / 4);
int threads = 1024; // intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / 1024 + 1)); // (batch_size);
fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(
input, bias, total_count, intermediate_size / 4);
}
template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t);
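// fused_bias_add: adds a broadcast bias of length hidden_size to every row of the
// input, four elements per thread (float4 for fp32, two __half2 pairs for fp16).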
__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % hidden_size];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
input_cast[offset] = data;
}
}
__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % hidden_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
input_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream)
{
int total_count = batch_size * (hidden_size / 4);
int threads = 1024; // hidden_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / threads + 1)); // (batch_size);
fused_bias_add<<<grid_dims, block_dims, 0, stream>>>(input, bias, total_count, hidden_size / 4);
}
template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t);
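// fused_bias_residual: combines the MLP output, the attention output and both bias
// terms with the residual stream; the launcher passes 1.0 / mp_size so the residual
// path is scaled by the reciprocal of the model-parallel degree.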
__global__ void fused_bias_residual(float* input,
float* output,
float* attn,
float* bias,
float* attnbias,
int total_count,
int intermediate_size,
float mp_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);
output_cast[offset] = data;
}
}
__global__ void fused_bias_residual(__half* input,
__half* output,
__half* attn,
__half* bias,
__half* attn_bias,
int total_count,
int intermediate_size,
float mp_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);
float2* bias_cast = reinterpret_cast<float2*>(bias);
float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x =
(low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
low_data.y =
(low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
high_data.x =
(high_data.x + high_res.x) * mp_size + (high_out.x + (high_bias.x + attn_high_bias.x));
high_data.y =
(high_data.y + high_res.y) * mp_size + (high_out.y + (high_bias.y + attn_high_bias.y));
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_residual(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int batch,
int hidden_dim,
int mp_size,
cudaStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
fused_bias_residual<<<grid_dims, block_dims, 0, stream>>>(
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void
launch_bias_residual<float>(float*, float*, float*, float*, float*, int, int, int, cudaStream_t);
template void launch_bias_residual<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
cudaStream_t);
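// gptj_residual_add: GPT-J style parallel residual, where the attention and MLP
// branches are added back onto the (scaled) residual in a single pass.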
__global__ void gptj_residual_add(float* input,
float* output,
float* attn,
float* bias,
float* attnbias,
int total_count,
int intermediate_size,
float mp_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x + attn_bias.x);
data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y + attn_bias.y);
data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z + attn_bias.z);
data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w + attn_bias.w);
output_cast[offset] = data;
}
}
__global__ void gptj_residual_add(__half* input,
__half* output,
__half* attn,
__half* bias,
__half* attn_bias,
int total_count,
int intermediate_size,
float mp_size)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);
float2* bias_cast = reinterpret_cast<float2*>(bias);
float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x =
low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x + attn_low_bias.x));
low_data.y =
low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y + attn_low_bias.y));
high_data.x =
high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x + attn_high_bias.x));
high_data.y =
high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y + attn_high_bias.y));
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_gptj_residual_add(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int hidden_dim,
int batch,
int mp_size,
cudaStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
gptj_residual_add<<<grid_dims, block_dims, 0, stream>>>(
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void launch_gptj_residual_add<float>(float*,
float*,
float*,
float*,
float*,
int,
int,
int,
cudaStream_t);
template void launch_gptj_residual_add<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
cudaStream_t);
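// moe_res_matmul: blends the residual with the MLP output using two per-element
// coefficient rows stored back-to-back in coef
// (mlp_out = coef2 * mlp_out + coef1 * residual).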
__global__ void moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim)
{
unsigned tid = threadIdx.x;
float4* residual_cast = reinterpret_cast<float4*>(residual);
float4* coef_cast = reinterpret_cast<float4*>(coef);
float4* mlp_out_cast = reinterpret_cast<float4*>(mlp_out);
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;
float4* coef_cast2 = coef_cast + hidden_dim;
while (tid < hidden_dim) {
float4 res = residual_cast[tid];
float4 mlp = mlp_out_cast[tid];
float4 coef1 = coef_cast[tid];
float4 coef2 = coef_cast2[tid];
mlp.x = mlp.x * coef2.x + res.x * coef1.x;
mlp.y = mlp.y * coef2.y + res.y * coef1.y;
mlp.z = mlp.z * coef2.z + res.z * coef1.z;
mlp.w = mlp.w * coef2.w + res.w * coef1.w;
mlp_out_cast[tid] = mlp;
tid += blockDim.x;
}
}
__global__ void moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim)
{
unsigned tid = threadIdx.x;
float2* residual_cast = reinterpret_cast<float2*>(residual);
float2* mlp_out_cast = reinterpret_cast<float2*>(mlp_out);
float2* coef_cast = reinterpret_cast<float2*>(coef);
float2* coef_cast2 = coef_cast + hidden_dim;
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;
while (tid < hidden_dim) {
float2 res = residual_cast[tid];
float2 coef1 = coef_cast[tid];
float2 coef2 = coef_cast2[tid]; // second coefficient row (coef_cast + hidden_dim)
float2 data = mlp_out_cast[tid];
__half* data_h = reinterpret_cast<__half*>(&data);
__half* coef1_h = reinterpret_cast<__half*>(&coef1);
__half* coef2_h = reinterpret_cast<__half*>(&coef2);
__half* res_h = reinterpret_cast<__half*>(&res);
data_h[0] = res_h[0] * coef1_h[0] + data_h[0] * coef2_h[0];
data_h[1] = res_h[1] * coef1_h[1] + data_h[1] * coef2_h[1];
data_h[2] = res_h[2] * coef1_h[2] + data_h[2] * coef2_h[2];
data_h[3] = res_h[3] * coef1_h[3] + data_h[3] * coef2_h[3];
mlp_out_cast[tid] = data;
tid += blockDim.x;
}
}
template <typename T>
void launch_moe_res_matmul(T* residual,
T* coef,
T* mlp_out,
int seq_len,
int hidden_dim,
cudaStream_t stream)
{
dim3 grid_dim(seq_len);
dim3 block_dim(1024);
moe_res_matmul<<<grid_dim, block_dim, 0, stream>>>(
residual, coef, mlp_out, seq_len, hidden_dim / 4);
}
template void launch_moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim,
cudaStream_t stream);
template void launch_moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim,
cudaStream_t stream);
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
#define MAX_CAP 4
#define MAX_SEQ 2048
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
__global__ void fused_bias_gelu(float* input,
const float* bias,
int total_count,
int intermediate_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
input_cast[offset] = data;
}
}
__global__ void fused_bias_gelu(__half* input,
const __half* bias,
int total_count,
int intermediate_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
input_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_gelu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int total_count = batch_size * (intermediate_size / 4);
int threads = 1024; // intermediate_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / threads + 1)); // (batch_size);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, total_count, intermediate_size / 4);
}
template void launch_bias_gelu<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, hipStream_t);
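// Fused bias-add with no activation, using the same float4 / __half2 vectorization
// scheme as the bias + GELU kernels above.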
__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 bias_data = bias_cast[offset % hidden_size];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
input_cast[offset] = data;
}
}
__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 bias_vec = bias_cast[offset % hidden_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
input_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, hipStream_t stream)
{
int total_count = batch_size * (hidden_size / 4);
int threads = 1024; // hidden_size / iterations / 4;
dim3 block_dims(threads);
dim3 grid_dims(((total_count - 1) / threads + 1)); // (batch_size);
hipLaunchKernelGGL(( fused_bias_add), dim3(grid_dims), dim3(block_dims), 0, stream, input, bias, total_count, hidden_size / 4);
}
template void launch_bias_add<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_add<__half>(__half*, const __half*, int, int, hipStream_t);
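// Fused residual + bias add used after the attention/MLP output projections:
// output = (input + attn) * mp_size + (output + bias + attn_bias),
// where mp_size here is the reciprocal of the model-parallel degree
// (the launcher below passes 1.0 / mp_size).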
__global__ void fused_bias_residual(float* input,
float* output,
float* attn,
float* bias,
float* attnbias,
int total_count,
int intermediate_size,
                                    float mp_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);
output_cast[offset] = data;
}
}
__global__ void fused_bias_residual(__half* input,
__half* output,
__half* attn,
__half* bias,
__half* attn_bias,
int total_count,
int intermediate_size,
                                    float mp_size)
{
#ifdef HALF_PRECISION_AVAILABLE
float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);
float2* bias_cast = reinterpret_cast<float2*>(bias);
float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x =
(low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
low_data.y =
(low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
high_data.x =
(high_data.x + high_res.x) * mp_size + (high_out.x + (high_bias.x + attn_high_bias.x));
high_data.y =
(high_data.y + high_res.y) * mp_size + (high_out.y + (high_bias.y + attn_high_bias.y));
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_bias_residual(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int batch,
int hidden_dim,
int mp_size,
hipStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
hipLaunchKernelGGL(( fused_bias_residual), dim3(grid_dims), dim3(block_dims), 0, stream,
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void
launch_bias_residual<float>(float*, float*, float*, float*, float*, int, int, int, hipStream_t);
template void launch_bias_residual<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
hipStream_t);
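// GPT-J style residual add where the attention and MLP branches run in parallel:
// output = input * mp_size + (output + attn + bias + attn_bias),
// with mp_size again being the reciprocal of the model-parallel degree.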
__global__ void gptj_residual_add(float* input,
float* output,
float* attn,
float* bias,
float* attnbias,
int total_count,
int intermediate_size,
float mp_size)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
float4* attn_cast = reinterpret_cast<float4*>(attn);
float4* bias_cast = reinterpret_cast<float4*>(bias);
float4* attnbias_cast = reinterpret_cast<float4*>(attnbias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 data = input_cast[offset];
float4 out = output_cast[offset];
float4 res_vec = attn_cast[offset];
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x + attn_bias.x);
data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y + attn_bias.y);
data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z + attn_bias.z);
data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w + attn_bias.w);
output_cast[offset] = data;
}
}
__global__ void gptj_residual_add(__half* input,
__half* output,
__half* attn,
__half* bias,
__half* attn_bias,
int total_count,
int intermediate_size,
float mp_size)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
float2* input_cast = reinterpret_cast<float2*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
float2* attn_cast = reinterpret_cast<float2*>(attn);
float2* bias_cast = reinterpret_cast<float2*>(bias);
float2* attnbias_cast = reinterpret_cast<float2*>(attn_bias);
int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 vals_vec = input_cast[offset];
float2 out_vec = output_cast[offset];
float2 res_vec = attn_cast[offset];
float2 bias_vec = bias_cast[offset % intermediate_size];
float2 attn_bias_vec = attnbias_cast[offset % intermediate_size];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* out_half = reinterpret_cast<__half2*>(&out_vec);
__half2* res_half = reinterpret_cast<__half2*>(&res_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
__half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_out = __half22float2(out_half[0]);
float2 high_out = __half22float2(out_half[1]);
float2 low_res = __half22float2(res_half[0]);
float2 high_res = __half22float2(res_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
float2 attn_low_bias = __half22float2(attnbias_half[0]);
float2 attn_high_bias = __half22float2(attnbias_half[1]);
low_data.x =
low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x + attn_low_bias.x));
low_data.y =
low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y + attn_low_bias.y));
high_data.x =
high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x + attn_high_bias.x));
high_data.y =
high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y + attn_high_bias.y));
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
output_cast[offset] = vals_vec;
}
#endif
}
template <typename T>
void launch_gptj_residual_add(T* input,
T* output,
T* attn,
T* bias,
T* attn_bias,
int hidden_dim,
int batch,
int mp_size,
hipStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
hipLaunchKernelGGL(( gptj_residual_add), dim3(grid_dims), dim3(block_dims), 0, stream,
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void launch_gptj_residual_add<float>(float*,
float*,
float*,
float*,
float*,
int,
int,
int,
hipStream_t);
template void launch_gptj_residual_add<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
hipStream_t);
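// MoE residual gating. coef packs two gate vectors of length hidden_dim back to back
// (coef1 followed by coef2); each block handles one token (row) and computes
// mlp_out = mlp_out * coef2 + residual * coef1, vectorized four elements at a time.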
__global__ void moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim)
{
unsigned tid = threadIdx.x;
float4* residual_cast = reinterpret_cast<float4*>(residual);
float4* coef_cast = reinterpret_cast<float4*>(coef);
float4* mlp_out_cast = reinterpret_cast<float4*>(mlp_out);
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;
float4* coef_cast2 = coef_cast + hidden_dim;
while (tid < hidden_dim) {
float4 res = residual_cast[tid];
float4 mlp = mlp_out_cast[tid];
float4 coef1 = coef_cast[tid];
float4 coef2 = coef_cast2[tid];
mlp.x = mlp.x * coef2.x + res.x * coef1.x;
mlp.y = mlp.y * coef2.y + res.y * coef1.y;
mlp.z = mlp.z * coef2.z + res.z * coef1.z;
mlp.w = mlp.w * coef2.w + res.w * coef1.w;
mlp_out_cast[tid] = mlp;
tid += blockDim.x;
}
}
__global__ void moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim)
{
unsigned tid = threadIdx.x;
float2* residual_cast = reinterpret_cast<float2*>(residual);
float2* mlp_out_cast = reinterpret_cast<float2*>(mlp_out);
float2* coef_cast = reinterpret_cast<float2*>(coef);
float2* coef_cast2 = coef_cast + hidden_dim;
residual_cast += blockIdx.x * hidden_dim;
mlp_out_cast += blockIdx.x * hidden_dim;
while (tid < hidden_dim) {
float2 res = residual_cast[tid];
float2 coef1 = coef_cast[tid];
float2 coef2 = coef_cast2[tid];
float2 data = mlp_out_cast[tid];
__half* data_h = reinterpret_cast<__half*>(&data);
__half* coef1_h = reinterpret_cast<__half*>(&coef1);
__half* coef2_h = reinterpret_cast<__half*>(&coef2);
__half* res_h = reinterpret_cast<__half*>(&res);
data_h[0] = res_h[0] * coef1_h[0] + data_h[0] * coef2_h[0];
data_h[1] = res_h[1] * coef1_h[1] + data_h[1] * coef2_h[1];
data_h[2] = res_h[2] * coef1_h[2] + data_h[2] * coef2_h[2];
data_h[3] = res_h[3] * coef1_h[3] + data_h[3] * coef2_h[3];
mlp_out_cast[tid] = data;
tid += blockDim.x;
}
}
template <typename T>
void launch_moe_res_matmul(T* residual,
T* coef,
T* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream)
{
dim3 grid_dim(seq_len);
dim3 block_dim(1024);
hipLaunchKernelGGL(( moe_res_matmul), dim3(grid_dim), dim3(block_dim), 0, stream,
residual, coef, mlp_out, seq_len, hidden_dim / 4);
}
template void launch_moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream);
template void launch_moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream);
#include <limits>
#include "custom_cuda_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define NORM_REG (MAX_REGISTERS)
namespace cg = cooperative_groups;
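// Layer norm over one row per block: each thread accumulates a strided slice of the
// row in registers, then the mean and variance are reduced with warp shuffles plus a
// shared-memory pass across warps before the normalize/scale/shift step.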
__global__ void fused_bias_residual_layer_norm(float* output,
const float* vals,
const float* gamma,
const float* beta,
float epsilon,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
float sum = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
output[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_bias_residual_layer_norm(__half* output,
const __half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
const __half2* vals_cast = reinterpret_cast<const __half2*>(vals);
__half2* out_cast = reinterpret_cast<__half2*>(output);
int k = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k++] = vals_cast[input_id + row * row_stride];
input_id += iteration_stride;
}
float sum = 0;
for (int f = k - 1; f >= 0; f--) {
float2 inp_f = __half22float2(inp_reg[f]);
sum += inp_f.x + inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
out_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_layer_norm(T* out,
T* vals,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream);
template <>
void launch_layer_norm<float>(float* out,
float* vals,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
out, vals, gamma, beta, epsilon, hidden_dim);
}
template <>
void launch_layer_norm<__half>(__half* out,
__half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
out, vals, gamma, beta, epsilon, hidden_dim / 2);
}
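// Layer norm with the residual + bias add fused into the input (applied only when
// mlp_after_attn is set), followed by the same two-pass mean/variance reduction.
// The res_add output path is currently commented out.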
__global__ void fused_residual_layer_norm(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
float res_f = (residual[input_id + row * row_stride]);
float bias_f = (bias[input_id]);
if (mlp_after_attn) inp_reg[k] += res_f + bias_f;
// if (preLN) res_add[input_id + row * row_stride] = inp_reg[k];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
norm[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_residual_layer_norm(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
__half2* norm_cast = reinterpret_cast<__half2*>(norm);
__half2* res_add_cast = reinterpret_cast<__half2*>(res_add);
__half2* residual_cast = reinterpret_cast<__half2*>(residual);
const __half2* bias_cast = reinterpret_cast<const __half2*>(bias);
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals_cast[input_id + row * row_stride];
float2 inp_f = __half22float2(inp_reg[k]);
float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]);
float2 bias_f = __half22float2(bias_cast[input_id]);
if (mlp_after_attn) {
inp_f.x += res_f.x + bias_f.x;
inp_f.y += res_f.y + bias_f.y;
}
inp_reg[k] = __float22half2_rn(inp_f);
// if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f);
// //inp_reg[k];
sum += inp_f.x + inp_f.y;
input_id += iteration_stride;
k++;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
norm_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_residual_layer_norm(T* norm,
T* res_add,
T* vals,
T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream);
template <>
void launch_residual_layer_norm<float>(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim,
preLN,
mlp_after_attn);
}
template <>
void launch_residual_layer_norm<__half>(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim / 2,
preLN,
mlp_after_attn);
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <limits>
#include "custom_hip_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define NORM_REG (MAX_REGISTERS)
namespace cg = cooperative_groups;
__global__ void fused_bias_residual_layer_norm(float* output,
const float* vals,
const float* gamma,
const float* beta,
float epsilon,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
float sum = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
output[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_bias_residual_layer_norm(__half* output,
const __half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
const __half2* vals_cast = reinterpret_cast<const __half2*>(vals);
__half2* out_cast = reinterpret_cast<__half2*>(output);
int k = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k++] = vals_cast[input_id + row * row_stride];
input_id += iteration_stride;
}
float sum = 0;
for (int f = k - 1; f >= 0; f--) {
float2 inp_f = __half22float2(inp_reg[f]);
sum += inp_f.x + inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
out_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_layer_norm(T* out,
T* vals,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream);
template <>
void launch_layer_norm<float>(float* out,
float* vals,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
out, vals, gamma, beta, epsilon, hidden_dim);
}
template <>
void launch_layer_norm<__half>(__half* out,
__half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
out, vals, gamma, beta, epsilon, hidden_dim / 2);
}
__global__ void fused_residual_layer_norm(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
float res_f = (residual[input_id + row * row_stride]);
float bias_f = (bias[input_id]);
if (mlp_after_attn) inp_reg[k] += res_f + bias_f;
// if (preLN) res_add[input_id + row * row_stride] = inp_reg[k];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
norm[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_residual_layer_norm(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
__half2* norm_cast = reinterpret_cast<__half2*>(norm);
__half2* res_add_cast = reinterpret_cast<__half2*>(res_add);
__half2* residual_cast = reinterpret_cast<__half2*>(residual);
const __half2* bias_cast = reinterpret_cast<const __half2*>(bias);
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals_cast[input_id + row * row_stride];
float2 inp_f = __half22float2(inp_reg[k]);
float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]);
float2 bias_f = __half22float2(bias_cast[input_id]);
if (mlp_after_attn) {
inp_f.x += res_f.x + bias_f.x;
inp_f.y += res_f.y + bias_f.y;
}
inp_reg[k] = __float22half2_rn(inp_f);
// if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f);
// //inp_reg[k];
sum += inp_f.x + inp_f.y;
input_id += iteration_stride;
k++;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
norm_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_residual_layer_norm(T* norm,
T* res_add,
T* vals,
T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream);
template <>
void launch_residual_layer_norm<float>(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream, norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim,
preLN,
mlp_after_attn);
}
template <>
void launch_residual_layer_norm<__half>(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream, norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim / 2,
preLN,
mlp_after_attn);
}
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <vector>
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});
#define MAX_OUT_TOKES 10
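// Softmax over attention scores. The input may be 3-D [batch, seq, soft_len] or
// 4-D [batch, heads, seq, soft_len]; the mask is only forwarded when it has more
// than one dimension.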
template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores,
at::Tensor& attn_mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
bool async_op)
{
auto attn_scores_c = attn_scores.contiguous();
int bsz = attn_scores_c.size(0);
int seq_len = attn_scores_c.size(1);
int len = attn_scores_c.sizes().size();
if (len > 3) seq_len = attn_scores_c.size(2);
int soft_len = attn_scores_c.size(2);
if (len > 3) soft_len = attn_scores_c.size(3);
int heads = 1;
if (len > 3) heads = attn_scores_c.size(1);
launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(),
(attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
triangular,
recompute,
local_attention,
window_size,
bsz,
heads,
seq_len,
soft_len,
1.0,
Context::Instance().GetCurrentStream(async_op));
return attn_scores_c;
}
template <typename T>
void allocate_workspace(size_t hidden_dim,
size_t max_seq_len,
size_t batch_size,
size_t head_size = 128)
{
size_t _workSpaceSize = (hidden_dim * batch_size * max_seq_len);
Context::Instance().GenWorkSpace(_workSpaceSize * sizeof(T));
}
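// einsum('sec,sm->ecm') implemented as a single GEMM that writes into the shared
// inference workspace, which is lazily allocated on first use.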
template <typename T>
at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
{
auto options = at::TensorOptions()
.dtype(Q.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)Context::Instance().GetWorkSpace();
float alpha = 1;
float gemm_beta = 0.0;
if (!workspace) {
allocate_workspace<T>(W.size(1), MAX_OUT_TOKES, Q.size(0));
workspace = (T*)Context::Instance().GetWorkSpace();
}
auto O = at::from_blob(workspace, {Q.size(1), Q.size(2), W.size(1)}, options);
unsigned m = W.size(1);
unsigned n = Q.size(1) * Q.size(2);
unsigned k = Q.size(0);
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_T,
m,
n,
k,
&alpha,
&gemm_beta,
(T*)W.data_ptr(),
(T*)Q.data_ptr(),
(T*)O.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
return O;
}
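// Unfused attention: a strided-batched GEMM for the raw scores (scaled by
// norm_factor), softmax via ds_softmax, then a second strided-batched GEMM with the
// value cache to produce the context output.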
template <typename T>
void attention_unfused(at::Tensor& prev_key_cont,
at::Tensor& query_cont,
at::Tensor& attn_mask,
at::Tensor& prev_value_cont,
at::Tensor& output,
int& bsz,
int& seq_len,
int& soft_len,
int& heads,
float& norm_factor,
bool triangular,
bool recompute,
bool local_attention,
int window_size)
{
auto options = at::TensorOptions()
.dtype(query_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
float alpha = norm_factor;
float gemm_beta = 0.0;
auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options);
int k = prev_value_cont.size(2) / heads;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
soft_len,
seq_len,
k,
&alpha,
&gemm_beta,
(T*)prev_key_cont.data_ptr(),
(T*)query_cont.data_ptr(),
(T*)attn_score.data_ptr(),
CUBLAS_OP_N,
CUBLAS_OP_N,
soft_len * k,
seq_len * k,
seq_len * soft_len,
bsz * heads,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
attn_score = ds_softmax<T>(
attn_score, attn_mask, triangular, recompute, local_attention, window_size, false);
alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
k,
seq_len,
soft_len,
&alpha,
&gemm_beta,
(T*)prev_value_cont.data_ptr(),
(T*)attn_score.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_OP_N,
CUBLAS_OP_N,
soft_len * k,
seq_len * soft_len,
seq_len * k,
bsz * heads,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
std::vector<at::Tensor> ds_softmax_context(at::Tensor& query,
at::Tensor& prev_key,
at::Tensor& new_key,
at::Tensor& attn_mask,
at::Tensor& prev_value,
at::Tensor& new_value,
int heads,
float norm_factor,
bool merging,
bool triangular,
bool local_attention,
int window_size,
bool no_masking)
{
auto query_cont = query.contiguous();
auto prev_key_cont = prev_key.contiguous();
auto prev_value_cont = prev_value.contiguous();
int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0);
// Attn_Score [ batch Head Sequence-length Softmax-length]
int bsz = query_cont.size(0);
int seq_len = query_cont.size(1);
int soft_len = prev_value.size(1);
auto options = at::TensorOptions()
.dtype(query_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output =
at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options);
attention_unfused<T>(prev_key_cont,
query_cont,
attn_mask, //(no_masking ? nullptr : (T*)attn_mask.data_ptr()),
prev_value_cont,
output,
bsz,
seq_len,
soft_len,
heads,
norm_factor,
(triangular && (new_size == 0)),
(new_size == 0),
local_attention,
window_size);
return {output, prev_key, prev_value};
}
template <typename T>
at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
{
auto input_cont = input.contiguous();
int bsz = input_cont.size(0) * input_cont.size(1);
int intermediate_size = input_cont.size(2);
launch_bias_gelu((T*)input_cont.data_ptr(),
(T*)bias.data_ptr(),
intermediate_size,
bsz,
Context::Instance().GetCurrentStream());
return input_cont;
}
template <typename T>
at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias)
{
auto input_cont = input.contiguous();
auto residual_cont = residual.contiguous();
int bsz = input_cont.size(0) * input_cont.size(1);
// launch_bias_residual((T*)input_cont.data_ptr(),
// (T*)residual_cont.data_ptr(),
// (T*)bias.data_ptr(),
// bsz,
// input_cont.size(2),
// (bias.size(0) > 1),
// Context::Instance().GetCurrentStream());
return input_cont;
}
template <typename T>
at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& betta, float epsilon)
{
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm = at::empty_like(input_cont);
launch_layer_norm((T*)inp_norm.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)gamma.data_ptr(),
(T*)betta.data_ptr(),
epsilon,
bsz,
input_cont.size(2),
Context::Instance().GetCurrentStream());
return inp_norm;
}
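// QKV projection helper: layer-norms the input, runs one GEMM against the fused QKV
// weight, and optionally adds the bias; ds_qkv_gemm below returns the normalized
// input alongside the projection so the caller can reuse it.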
template <typename T>
at::Tensor qkv_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias)
{
auto inp_norm = ds_layernorm<T>(input, gamma, beta, epsilon);
// cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
int bsz = input.size(0) * input.size(1);
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)inp_norm.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return inp_norm;
}
template <typename T>
std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm =
qkv_unfused_cublas<T>(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias);
return {output, inp_norm};
}
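// Int8 path: dequantize the int8 weight into a temporary buffer in the activation
// dtype with launch_dequantize, then run a regular cuBLAS GEMM on the dequantized
// weights.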
template <typename T>
void quantized_gemm(at::Tensor& output,
at::Tensor& input,
at::Tensor& weight,
at::Tensor& qscale,
int groups,
int merge_count)
{
int bsz = input.size(0) * input.size(1);
auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto weight16 = at::empty({weight.size(0), weight.size(1)}, options);
launch_dequantize((T*)weight16.data_ptr(),
(int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(),
weight.size(1),
weight.size(0),
groups,
merge_count,
Context::Instance().GetCurrentStream());
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight16.data_ptr(),
(T*)input.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
template <typename T>
at::Tensor ds_qkv_gemm_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool add_bias)
{
int bsz = input.size(0) * input.size(1);
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto inp_norm = ds_layernorm<T>(input_cont, gamma, beta, epsilon);
quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input_cont.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_linear_layer_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& q_scale,
int groups)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int bsz = input_cont.size(0) * input_cont.size(1);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream(async_op));
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input_cont.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
return output;
}
template <typename T>
at::Tensor ds_vector_matmul_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& q_scale,
int groups,
int merge_count)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, merge_count);
return output;
}
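// MLP block: fused residual/bias layer norm on the input, a GEMM against the first
// MLP weight, then fused bias + GELU on the result.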
template <typename T>
void mlp_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
{
int bsz = input.size(0) * input.size(1);
auto inp_norm = at::empty_like(input);
launch_residual_layer_norm((T*)inp_norm.data_ptr(),
(T*)nullptr,
(T*)input.data_ptr(),
(T*)residual.data_ptr(),
(T*)input_bias.data_ptr(),
(T*)gamma.data_ptr(),
(T*)beta.data_ptr(),
epsilon,
bsz,
input.size(2),
preLayerNorm,
mlp_after_attn,
Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)inp_norm.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
}
template <typename T>
at::Tensor ds_mlp_gemm(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
mlp_unfused_cublas<T>(output,
mlp_after_attn ? input : residual,
residual,
input_bias,
weight,
bias,
gamma,
beta,
epsilon,
preLayerNorm,
mlp_after_attn);
return output;
}
template <typename T>
std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool preLayerNorm)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm = at::empty_like(input_cont);
auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm);
// computing the blocking across K dimension
// launch_residual_layer_norm((T*)inp_norm.data_ptr(),
// (T*)residual_add.data_ptr(),
// (T*)input_cont.data_ptr(),
// (T*)residual.data_ptr(),
// (T*)input_bias.data_ptr(),
// (T*)gamma.data_ptr(),
// (T*)beta.data_ptr(),
// epsilon,
// bsz,
// input_cont.size(2),
// preLayerNorm,
// Context::Instance().GetCurrentStream());
quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return {output, residual_add};
}
template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& weight_out,
const float epsilon,
bool preLayerNorm,
bool async_op)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto intermediate =
at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)intermediate.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
launch_bias_gelu((T*)intermediate.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight_out.size(1),
bsz,
intermediate.size(2),
&alpha,
&gemm_beta,
(T*)weight_out.data_ptr(),
(T*)intermediate.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
// cudaEventRecord(Context::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true));
return output;
}
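// Dispatches the residual addition on dtype (fp32 vs fp16) and on layout:
// launch_bias_residual for the standard post-attention MLP ordering,
// launch_gptj_residual_add when the attention and MLP branches run in parallel.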
void residual_add_bias(at::Tensor& output,
at::Tensor& input,
at::Tensor& attention_output,
at::Tensor& output_b,
at::Tensor& attention_b,
int mp_size,
bool mlp_after_attn)
{
int bsz = input.size(0) * input.size(1);
int hidden_size = input.size(2);
// cudaStreamWaitEvent(
// Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0);
if (input.scalar_type() == at::kFloat)
if (mlp_after_attn)
launch_bias_residual((float*)input.data_ptr(),
(float*)output.data_ptr(),
(float*)attention_output.data_ptr(),
(float*)output_b.data_ptr(),
(float*)attention_b.data_ptr(),
bsz,
hidden_size,
mp_size,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add<float>((float*)input.data_ptr(),
(float*)output.data_ptr(),
(float*)attention_output.data_ptr(),
(float*)output_b.data_ptr(),
(float*)attention_b.data_ptr(),
hidden_size,
bsz,
mp_size,
Context::Instance().GetCurrentStream());
else if (mlp_after_attn)
launch_bias_residual((__half*)input.data_ptr(),
(__half*)output.data_ptr(),
(__half*)attention_output.data_ptr(),
(__half*)output_b.data_ptr(),
(__half*)attention_b.data_ptr(),
bsz,
hidden_size,
mp_size,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add<__half>((__half*)input.data_ptr(),
(__half*)output.data_ptr(),
(__half*)attention_output.data_ptr(),
(__half*)output_b.data_ptr(),
(__half*)attention_b.data_ptr(),
hidden_size,
bsz,
mp_size,
Context::Instance().GetCurrentStream());
}
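// Applies rotary positional embeddings to contiguous copies of the query and key
// tensors in place and returns those copies.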
std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
at::Tensor& key_layer,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
bool rotate_half,
bool rotate_every_two)
{
auto query_cont = mixed_query.contiguous();
auto key_cont = key_layer.contiguous();
unsigned bsz = mixed_query.size(0);
unsigned head_size = mixed_query.size(2) / num_heads;
unsigned seq_len = mixed_query.size(1);
if (mixed_query.scalar_type() == at::kFloat)
launch_apply_rotary_pos_emb<float>((float*)query_cont.data_ptr(),
(float*)key_cont.data_ptr(),
head_size,
seq_len,
rotary_dim,
offset,
num_heads,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream());
else
launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
(__half*)key_cont.data_ptr(),
head_size,
seq_len,
rotary_dim,
offset,
num_heads,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream());
return {query_cont, key_cont};
}
template <typename T>
at::Tensor fused_gemm_gelu_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool preLayerNorm)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& output)
{
int M = moe_res.size(0) * moe_res.size(1);
int N = moe_res.size(2);
Context::Instance().SynchComm();
if (moe_res.scalar_type() == at::kFloat) {
launch_moe_res_matmul<float>((float*)moe_res.data_ptr(),
(float*)coef.data_ptr(),
(float*)output.data_ptr(),
M,
N,
at::cuda::getCurrentCUDAStream());
} else {
launch_moe_res_matmul<__half>((__half*)moe_res.data_ptr(),
(__half*)coef.data_ptr(),
(__half*)output.data_ptr(),
M,
N,
at::cuda::getCurrentCUDAStream());
}
return output;
}
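// Python bindings exposing the inference ops above to torch.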
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def(
"softmax_context_fp32", &ds_softmax_context<float>, "DeepSpeed attention with fp32 (CUDA)");
m.def("softmax_context_fp16",
&ds_softmax_context<__half>,
"DeepSpeed attention with fp32 (CUDA)");
m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_residual_fp32",
&ds_bias_residual<float>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
m.def("bias_residual_fp16",
&ds_bias_residual<__half>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
m.def("layer_norm_fp32", &ds_layernorm<float>, "DeepSpeed layer-norm with fp32 (CUDA)");
m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)");
m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)");
m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)");
m.def("mlp_gemm_fp32", &ds_mlp_gemm<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)");
m.def("vector_matmul_fp32", &ds_vector_matmul<float>, "DeepSpeed vector-MM with fp32 (CUDA)");
m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)");
m.def("vector_matmul_int8",
&ds_vector_matmul_int8<__half>,
"DeepSpeed vector-MM with int8 (CUDA)");
m.def("linear_layer_fp32", &ds_linear_layer<float>, "DeepSpeed linear_layer with fp32 (CUDA)");
m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)");
m.def("linear_layer_int8",
&ds_linear_layer_int8<__half>,
"DeepSpeed linear_layer with int8 (CUDA)");
m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("residual_add", &residual_add_bias, "DeepSpeed mlp with fp16 (CUDA)");
m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
m.def("einsum_sec_sm_ecm_fp32",
&einsum_sec_sm_ecm<float>,
"DeepSpeed vector-MM with fp32 (CUDA)");
m.def("einsum_sec_sm_ecm_fp16",
&einsum_sec_sm_ecm<__half>,
"DeepSpeed vector-MM with fp16 (CUDA)");
m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include <vector>
#include "context_hip.h"
#include "cublas_wrappers_hip.h"
#include "custom_hip_layers.h"
std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});
#define MAX_OUT_TOKES 10
template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores,
at::Tensor& attn_mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
bool async_op)
{
auto attn_scores_c = attn_scores.contiguous();
int bsz = attn_scores_c.size(0);
int seq_len = attn_scores_c.size(1);
int len = attn_scores_c.sizes().size();
if (len > 3) seq_len = attn_scores_c.size(2);
int soft_len = attn_scores_c.size(2);
if (len > 3) soft_len = attn_scores_c.size(3);
int heads = 1;
if (len > 3) heads = attn_scores_c.size(1);
launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(),
(attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
triangular,
recompute,
local_attention,
window_size,
bsz,
heads,
seq_len,
soft_len,
1.0,
Context::Instance().GetCurrentStream(async_op));
return attn_scores_c;
}
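// Note on shapes: ds_softmax accepts either a 3-D [batch, seq_len, soft_len] or a
// 4-D [batch, heads, seq_len, soft_len] score tensor and runs the softmax over the
// last dimension in place on the contiguous copy it returns. For the 4-D case the
// sizes above map as bsz = size(0), heads = size(1), seq_len = size(2),
// soft_len = size(3). The mask pointer is forwarded only when attn_mask has more
// than one dimension; otherwise nullptr is passed and the kernel skips the mask add.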
template <typename T>
void allocate_workspace(size_t hidden_dim,
size_t max_seq_len,
size_t batch_size,
size_t head_size = 128)
{
size_t _workSpaceSize = (hidden_dim * batch_size * max_seq_len);
Context::Instance().GenWorkSpace(_workSpaceSize * sizeof(T));
}
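// Workspace sizing sketch: the inference workspace holds a single activation-sized
// buffer of hidden_dim * batch_size * max_seq_len elements of T (head_size is
// currently unused here). For example, hidden_dim = 4096, batch_size = 1,
// max_seq_len = 10 (MAX_OUT_TOKES) with T = __half would request
// 4096 * 1 * 10 * sizeof(__half) = 81,920 bytes.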
template <typename T>
at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
{
auto options = at::TensorOptions()
.dtype(Q.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)Context::Instance().GetWorkSpace();
float alpha = 1;
float gemm_beta = 0.0;
if (!workspace) {
allocate_workspace<T>(W.size(1), MAX_OUT_TOKES, Q.size(0));
workspace = (T*)Context::Instance().GetWorkSpace();
}
auto O = at::from_blob(workspace, {Q.size(1), Q.size(2), W.size(1)}, options);
unsigned m = W.size(1);
unsigned n = Q.size(1) * Q.size(2);
unsigned k = Q.size(0);
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_transpose,
m,
n,
k,
&alpha,
&gemm_beta,
(T*)W.data_ptr(),
(T*)Q.data_ptr(),
(T*)O.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
return O;
}
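// Rough mapping of einsum("sec,sm->ecm") onto a single GEMM: with Q flattened to an
// (s, e*c) matrix and W viewed as (s, m), the call above computes
//   O(e*c, m) = Q^T * W
// using m = W.size(1), n = e*c, k = s, with W passed non-transposed and Q transposed
// (the swap accounts for rocBLAS' column-major convention). O aliases the shared
// inference workspace through at::from_blob, so it must be consumed before the next
// op that reuses that buffer.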
template <typename T>
void attention_unfused(at::Tensor& prev_key_cont,
at::Tensor& query_cont,
at::Tensor& attn_mask,
at::Tensor& prev_value_cont,
at::Tensor& output,
int& bsz,
int& seq_len,
int& soft_len,
int& heads,
float& norm_factor,
bool triangular,
bool recompute,
bool local_attention,
int window_size)
{
auto options = at::TensorOptions()
.dtype(query_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
float alpha = norm_factor;
float gemm_beta = 0.0;
auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options);
int k = prev_value_cont.size(2) / heads;
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
soft_len,
seq_len,
k,
&alpha,
&gemm_beta,
(T*)prev_key_cont.data_ptr(),
(T*)query_cont.data_ptr(),
(T*)attn_score.data_ptr(),
rocblas_operation_none,
rocblas_operation_none,
soft_len * k,
seq_len * k,
seq_len * soft_len,
bsz * heads,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
attn_score = ds_softmax<T>(
attn_score, attn_mask, triangular, recompute, local_attention, window_size, false);
alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
k,
seq_len,
soft_len,
&alpha,
&gemm_beta,
(T*)prev_value_cont.data_ptr(),
(T*)attn_score.data_ptr(),
(T*)output.data_ptr(),
rocblas_operation_none,
rocblas_operation_none,
soft_len * k,
seq_len * soft_len,
seq_len * k,
bsz * heads,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
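// attention_unfused sketch: two strided-batched GEMMs bracket the softmax,
//   scores  = norm_factor * (Q x K^T)   // [bsz * heads, seq_len, soft_len]
//   probs   = softmax(scores + mask)    // ds_softmax<T> above, in place
//   context = probs x V                 // [bsz * heads, seq_len, head_dim]
// where head_dim = prev_value.size(2) / heads and both GEMMs run with a batch
// count of bsz * heads on the current stream.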
template <typename T>
std::vector<at::Tensor> ds_softmax_context(at::Tensor& query,
at::Tensor& prev_key,
at::Tensor& new_key,
at::Tensor& attn_mask,
at::Tensor& prev_value,
at::Tensor& new_value,
int heads,
float norm_factor,
bool merging,
bool triangular,
bool local_attention,
int window_size,
bool no_masking)
{
auto query_cont = query.contiguous();
auto prev_key_cont = prev_key.contiguous();
auto prev_value_cont = prev_value.contiguous();
int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0);
// Attn_Score [ batch Head Sequence-length Softmax-length]
int bsz = query_cont.size(0);
int seq_len = query_cont.size(1);
int soft_len = prev_value.size(1);
auto options = at::TensorOptions()
.dtype(query_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output =
at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options);
attention_unfused<T>(prev_key_cont,
query_cont,
attn_mask, //(no_masking ? nullptr : (T*)attn_mask.data_ptr()),
prev_value_cont,
output,
bsz,
seq_len,
soft_len,
heads,
norm_factor,
(triangular && (new_size == 0)),
(new_size == 0),
local_attention,
window_size);
return {output, prev_key, prev_value};
}
template <typename T>
at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
{
auto input_cont = input.contiguous();
int bsz = input_cont.size(0) * input_cont.size(1);
int intermediate_size = input_cont.size(2);
launch_bias_gelu((T*)input_cont.data_ptr(),
(T*)bias.data_ptr(),
intermediate_size,
bsz,
Context::Instance().GetCurrentStream());
return input_cont;
}
template <typename T>
at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias)
{
auto input_cont = input.contiguous();
auto residual_cont = residual.contiguous();
int bsz = input_cont.size(0) * input_cont.size(1);
// launch_bias_residual((T*)input_cont.data_ptr(),
// (T*)residual_cont.data_ptr(),
// (T*)bias.data_ptr(),
// bsz,
// input_cont.size(2),
// (bias.size(0) > 1),
// Context::Instance().GetCurrentStream());
return input_cont;
}
template <typename T>
at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& betta, float epsilon)
{
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm = at::empty_like(input_cont);
launch_layer_norm((T*)inp_norm.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)gamma.data_ptr(),
(T*)betta.data_ptr(),
epsilon,
bsz,
input_cont.size(2),
Context::Instance().GetCurrentStream());
return inp_norm;
}
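// ds_layernorm normalizes each of the bsz = size(0) * size(1) token rows over the
// hidden dimension with affine parameters gamma/betta and writes the result into a
// freshly allocated tensor; assuming the standard formulation for launch_layer_norm,
//   out[b, s, :] = gamma * (x[b, s, :] - mean) / sqrt(var + epsilon) + betta.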
template <typename T>
at::Tensor qkv_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias)
{
auto inp_norm = ds_layernorm<T>(input, gamma, beta, epsilon);
// hipEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
int bsz = input.size(0) * input.size(1);
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)inp_norm.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return inp_norm;
}
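// Note on the GEMM arguments above: rocBLAS is column-major, so the row-major
// product output[bsz, out] = inp_norm[bsz, hidden] x weight[hidden, out] is
// expressed as C(m, n) = A(m, k) * B(k, n) with m = weight.size(1) (output
// features), n = bsz (tokens), k = input.size(2) (hidden), and the weight passed
// as A. The layer-norm result inp_norm is returned alongside the GEMM output so
// the caller can reuse the normalized activations.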
template <typename T>
std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm =
qkv_unfused_cublas<T>(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias);
return {output, inp_norm};
}
template <typename T>
void quantized_gemm(at::Tensor& output,
at::Tensor& input,
at::Tensor& weight,
at::Tensor& qscale,
int groups,
int merge_count)
{
int bsz = input.size(0) * input.size(1);
auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto weight16 = at::empty({weight.size(0), weight.size(1)}, options);
launch_dequantize((T*)weight16.data_ptr(),
(int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(),
weight.size(1),
weight.size(0),
groups,
merge_count,
Context::Instance().GetCurrentStream());
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight16.data_ptr(),
(T*)input.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
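// quantized_gemm sketch: the int8 weight is first dequantized group-wise into a
// temporary buffer (weight16) using the per-group scales in qscale, then the
// regular cublas_gemm_ex path computes
//   output[bsz, out] = input[bsz, hidden] x dequant(weight)[hidden, out].
// In this code path the extra dequantization traffic buys reuse of the standard
// GEMM kernels instead of a dedicated int8 GEMM.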
template <typename T>
at::Tensor ds_qkv_gemm_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool add_bias)
{
int bsz = input.size(0) * input.size(1);
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto inp_norm = ds_layernorm<T>(input_cont, gamma, beta, epsilon);
quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input_cont.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_linear_layer_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& q_scale,
int groups)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int bsz = input_cont.size(0) * input_cont.size(1);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
rocblas_set_stream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream(async_op));
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input_cont.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
return output;
}
template <typename T>
at::Tensor ds_vector_matmul_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& q_scale,
int groups,
int merge_count)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, merge_count);
return output;
}
template <typename T>
void mlp_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
{
int bsz = input.size(0) * input.size(1);
auto inp_norm = at::empty_like(input);
launch_residual_layer_norm((T*)inp_norm.data_ptr(),
(T*)nullptr,
(T*)input.data_ptr(),
(T*)residual.data_ptr(),
(T*)input_bias.data_ptr(),
(T*)gamma.data_ptr(),
(T*)beta.data_ptr(),
epsilon,
bsz,
input.size(2),
preLayerNorm,
mlp_after_attn,
Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)inp_norm.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
}
template <typename T>
at::Tensor ds_mlp_gemm(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
mlp_unfused_cublas<T>(output,
mlp_after_attn ? input : residual,
residual,
input_bias,
weight,
bias,
gamma,
beta,
epsilon,
preLayerNorm,
mlp_after_attn);
return output;
}
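// ds_mlp_gemm pipeline (see mlp_unfused_cublas above):
//   1. fused residual add + input-bias + layer norm,
//   2. GEMM against the first MLP weight,
//   3. fused bias + GELU on the GEMM output.
// When mlp_after_attn is false (parallel attention/MLP layouts), the residual
// tensor rather than the attention output is fed into the fused layer norm.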
template <typename T>
std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool preLayerNorm)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm = at::empty_like(input_cont);
auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm);
// computing the blocking across K dimension
// launch_residual_layer_norm((T*)inp_norm.data_ptr(),
// (T*)residual_add.data_ptr(),
// (T*)input_cont.data_ptr(),
// (T*)residual.data_ptr(),
// (T*)input_bias.data_ptr(),
// (T*)gamma.data_ptr(),
// (T*)beta.data_ptr(),
// epsilon,
// bsz,
// input_cont.size(2),
// preLayerNorm,
// Context::Instance().GetCurrentStream());
quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return {output, residual_add};
}
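// Caution: in this int8 MLP path the launch_residual_layer_norm call is commented
// out, so inp_norm is only allocated with at::empty_like and the quantized GEMM
// consumes uninitialized data as written.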
template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& weight_out,
const float epsilon,
bool preLayerNorm,
bool async_op)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto intermediate =
at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)intermediate.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
launch_bias_gelu((T*)intermediate.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight_out.size(1),
bsz,
intermediate.size(2),
&alpha,
&gemm_beta,
(T*)weight_out.data_ptr(),
(T*)intermediate.data_ptr(),
(T*)output.data_ptr(),
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
// hipEventRecord(Context::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true));
return output;
}
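// fused_gemm_gelu sketch:
//   intermediate = GELU(input x weight + bias);
//   output       = intermediate x weight_out;
// both GEMMs run back to back on the current stream. The epsilon, preLayerNorm
// and async_op arguments appear unused in this code path.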
void residual_add_bias(at::Tensor& output,
at::Tensor& input,
at::Tensor& attention_output,
at::Tensor& output_b,
at::Tensor& attention_b,
int mp_size,
bool mlp_after_attn)
{
int bsz = input.size(0) * input.size(1);
int hidden_size = input.size(2);
// hipStreamWaitEvent(
// Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0);
if (input.scalar_type() == at::kFloat)
if (mlp_after_attn)
launch_bias_residual((float*)input.data_ptr(),
(float*)output.data_ptr(),
(float*)attention_output.data_ptr(),
(float*)output_b.data_ptr(),
(float*)attention_b.data_ptr(),
bsz,
hidden_size,
mp_size,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add<float>((float*)input.data_ptr(),
(float*)output.data_ptr(),
(float*)attention_output.data_ptr(),
(float*)output_b.data_ptr(),
(float*)attention_b.data_ptr(),
hidden_size,
bsz,
mp_size,
Context::Instance().GetCurrentStream());
else if (mlp_after_attn)
launch_bias_residual((__half*)input.data_ptr(),
(__half*)output.data_ptr(),
(__half*)attention_output.data_ptr(),
(__half*)output_b.data_ptr(),
(__half*)attention_b.data_ptr(),
bsz,
hidden_size,
mp_size,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add<__half>((__half*)input.data_ptr(),
(__half*)output.data_ptr(),
(__half*)attention_output.data_ptr(),
(__half*)output_b.data_ptr(),
(__half*)attention_b.data_ptr(),
hidden_size,
bsz,
mp_size,
Context::Instance().GetCurrentStream());
}
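// residual_add_bias dispatches on dtype (float vs __half) and on layer layout:
// with mlp_after_attn it performs the usual post-MLP residual + bias add, while
// the other branch (launch_gptj_residual_add) sums the parallel attention and MLP
// outputs into the same residual. mp_size is forwarded so the kernels can account
// for tensor-model-parallel partitioning; the exact scaling lives inside them.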
std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
at::Tensor& key_layer,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
bool rotate_half,
bool rotate_every_two)
{
auto query_cont = mixed_query.contiguous();
auto key_cont = key_layer.contiguous();
unsigned bsz = mixed_query.size(0);
unsigned head_size = mixed_query.size(2) / num_heads;
unsigned seq_len = mixed_query.size(1);
if (mixed_query.scalar_type() == at::kFloat)
launch_apply_rotary_pos_emb<float>((float*)query_cont.data_ptr(),
(float*)key_cont.data_ptr(),
head_size,
seq_len,
rotary_dim,
offset,
num_heads,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream());
else
launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
(__half*)key_cont.data_ptr(),
head_size,
seq_len,
rotary_dim,
offset,
num_heads,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream());
return {query_cont, key_cont};
}
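// apply_rotary_pos_emb applies rotary position embeddings in place to contiguous
// copies of the query and key tensors and returns those copies. head_size is
// hidden / num_heads, rotary_dim limits how many dimensions per head are rotated,
// offset shifts the position index (useful for incremental decoding), and
// rotate_half / rotate_every_two select between the two common rotation layouts
// (rotating half blocks vs. rotating every adjacent pair).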
template <typename T>
at::Tensor fused_gemm_gelu_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool preLayerNorm)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& output)
{
int M = moe_res.size(0) * moe_res.size(1);
int N = moe_res.size(2);
Context::Instance().SynchComm();
if (moe_res.scalar_type() == at::kFloat) {
launch_moe_res_matmul<float>((float*)moe_res.data_ptr(),
(float*)coef.data_ptr(),
(float*)output.data_ptr(),
M,
N,
at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
} else {
launch_moe_res_matmul<__half>((__half*)moe_res.data_ptr(),
(__half*)coef.data_ptr(),
(__half*)output.data_ptr(),
M,
N,
at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def(
"softmax_context_fp32", &ds_softmax_context<float>, "DeepSpeed attention with fp32 (CUDA)");
m.def("softmax_context_fp16",
&ds_softmax_context<__half>,
"DeepSpeed attention with fp32 (CUDA)");
m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_residual_fp32",
&ds_bias_residual<float>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
m.def("bias_residual_fp16",
&ds_bias_residual<__half>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
m.def("layer_norm_fp32", &ds_layernorm<float>, "DeepSpeed layer-norm with fp32 (CUDA)");
m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)");
m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)");
m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)");
m.def("mlp_gemm_fp32", &ds_mlp_gemm<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)");
m.def("vector_matmul_fp32", &ds_vector_matmul<float>, "DeepSpeed vector-MM with fp32 (CUDA)");
m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)");
m.def("vector_matmul_int8",
&ds_vector_matmul_int8<__half>,
"DeepSpeed vector-MM with int8 (CUDA)");
m.def("linear_layer_fp32", &ds_linear_layer<float>, "DeepSpeed linear_layer with fp32 (CUDA)");
m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)");
m.def("linear_layer_int8",
&ds_linear_layer_int8<__half>,
"DeepSpeed linear_layer with int8 (CUDA)");
m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("residual_add", &residual_add_bias, "DeepSpeed mlp with fp16 (CUDA)");
m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
m.def("einsum_sec_sm_ecm_fp32",
&einsum_sec_sm_ecm<float>,
"DeepSpeed vector-MM with fp32 (CUDA)");
m.def("einsum_sec_sm_ecm_fp16",
&einsum_sec_sm_ecm<__half>,
"DeepSpeed vector-MM with fp16 (CUDA)");
m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
}
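// Binding convention: each op is exported once per supported precision with a
// _fp32 / _fp16 / _int8 suffix, and the kernels run on the current PyTorch stream
// obtained through Context::Instance(). A minimal, hypothetical C++ smoke test of
// one binding could look like:
//   auto scores = at::randn({1, 8, 16, 16}, at::kCUDA).to(at::kHalf);
//   auto mask   = at::zeros({1, 1, 1, 16}, at::kCUDA).to(at::kHalf);
//   auto probs  = ds_softmax<__half>(scores, mask,
//                                    /*triangular=*/true, /*recompute=*/false,
//                                    /*local_attention=*/false, /*window_size=*/1,
//                                    /*async_op=*/false);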
#include <limits>
#include "custom_cuda_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define ATTN_THREADS 1024
#define MAX_REG_SIZE 8
#define minus_infinity -10000.0
void CheckCudaErrorAux(const char* file, unsigned line)
{
cudaError_t err = cudaGetLastError();
if (err == cudaSuccess) return;
std::cerr << cudaGetErrorString(err) << "(" << err << ") at " << file << ":" << line
<< std::endl;
throw std::runtime_error("CUDA ERROR!!!\n");
}
#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__)
namespace cg = cooperative_groups;
__global__ void attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale,
int iterations,
int reduceWidth)
{
#ifdef HALF_PRECISION_AVAILABLE
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float2 low_data[MAX_REG_SIZE];
float2 high_data[MAX_REG_SIZE];
__half2 h_scale = __float2half2_rn(scale);
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? __half2float(vals[data_id + 3])
: minus_infinity;
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
}
} else {
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) &&
(data_id + 1) > window_stride) &&
(data_id + 1) < sequence_length)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) &&
(data_id + 2) > window_stride) &&
(data_id + 2) < sequence_length)
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = minus_infinity;
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
if ((data_id + 1) < sequence_length)
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
if ((data_id + 2) < sequence_length)
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
}
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
} else {
low_data[i].x = minus_infinity;
low_data[i].y = minus_infinity;
high_data[i].x = minus_infinity;
high_data[i].y = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = low_data[i].x / sum;
vals[data_id + 1] = low_data[i].y / sum;
vals[data_id + 2] = high_data[i].x / sum;
vals[data_id + 3] = high_data[i].y / sum;
} else {
vals[data_id] = low_data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = low_data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = high_data[i].x / sum;
}
}
}
}
#endif
}
__global__ void attn_softmax_v2(float* vals,
float* attn_mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale,
int iterations,
int reduceWidth)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float4 data[MAX_REG_SIZE];
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity);
data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? vals[data_id + 1]
: minus_infinity;
data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? vals[data_id + 2]
: minus_infinity;
data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? vals[data_id + 3]
: minus_infinity;
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
data[i].y += attn_mask[data_id + mask_offset + 1];
data[i].z += attn_mask[data_id + mask_offset + 2];
data[i].w += attn_mask[data_id + mask_offset + 3];
}
} else {
data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride && (data_id + 1) < sequence_length)
? (vals[data_id + 1])
: minus_infinity;
data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride && (data_id + 2) < sequence_length)
? (vals[data_id + 2])
: minus_infinity;
data[i].w = minus_infinity;
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
if ((data_id + 1) < sequence_length)
data[i].y += attn_mask[data_id + mask_offset + 1];
if ((data_id + 2) < sequence_length)
data[i].z += attn_mask[data_id + mask_offset + 2];
}
}
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = data[i].x / sum;
vals[data_id + 1] = data[i].y / sum;
vals[data_id + 2] = data[i].z / sum;
vals[data_id + 3] = data[i].w / sum;
} else {
vals[data_id] = data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum;
}
}
}
}
}
template <typename T>
void launch_attn_softmax_v2(T* vals,
T* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
cudaStream_t stream)
{
int total_count = batch_size * heads * num_seq;
dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1);
dim3 block_dim(ATTN_THREADS);
const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE;
const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;
if (sequence_length <= 32768)
attn_softmax_v2<<<grid_dim, block_dim, 0, stream>>>(
vals,
mask,
triangular,
recompute,
local_attention,
window_size,
total_count,
(triangular ? (heads * batch_size) : heads),
sequence_length,
num_seq,
scale,
iterations,
reduce_width);
else
throw std::runtime_error("Unsupport Seq_Length!");
}
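// Launch geometry sketch: each softmax row is handled by a group of reduce_width
// threads; every thread loads four elements per iteration into registers, so a row
// of up to reduce_width * 4 * MAX_REG_SIZE elements fits. With ATTN_THREADS = 1024
// and WARP_SIZE = 32 this caps the supported sequence length at 32768, which is why
// longer sequences throw above; a block processes ATTN_THREADS / reduce_width rows.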
template void launch_attn_softmax_v2(float* vals,
float* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
cudaStream_t stream);
template void launch_attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
cudaStream_t stream);
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <limits>
#include "custom_hip_layers.h"
//#include <cuda_profiler_api.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define ATTN_THREADS 1024
#define MAX_REG_SIZE 8
#define minus_infinity -10000.0
void CheckCudaErrorAux(const char* file, unsigned line)
{
hipError_t err = hipGetLastError();
if (err == hipSuccess) return;
std::cerr << hipGetErrorString(err) << "(" << err << ") at " << file << ":" << line
<< std::endl;
throw std::runtime_error("CUDA ERROR!!!\n");
}
#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__)
namespace cg = cooperative_groups;
__global__ void attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale,
int iterations,
int reduceWidth)
{
#ifdef HALF_PRECISION_AVAILABLE
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float2 low_data[MAX_REG_SIZE];
float2 high_data[MAX_REG_SIZE];
__half2 h_scale = __float2half2_rn(scale);
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? __half2float(vals[data_id + 3])
: minus_infinity;
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
}
} else {
low_data[i].x = data_id > window_stride ? __half2float(vals[data_id])
: minus_infinity;
low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) &&
(data_id + 1) > window_stride) &&
(data_id + 1) < sequence_length)
? __half2float(vals[data_id + 1])
: minus_infinity;
high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) &&
(data_id + 2) > window_stride) &&
(data_id + 2) < sequence_length)
? __half2float(vals[data_id + 2])
: minus_infinity;
high_data[i].y = minus_infinity;
if (mask && recompute) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
if ((data_id + 1) < sequence_length)
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
if ((data_id + 2) < sequence_length)
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
}
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
} else {
low_data[i].x = minus_infinity;
low_data[i].y = minus_infinity;
high_data[i].x = minus_infinity;
high_data[i].y = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = low_data[i].x / sum;
vals[data_id + 1] = low_data[i].y / sum;
vals[data_id + 2] = high_data[i].x / sum;
vals[data_id + 3] = high_data[i].y / sum;
} else {
vals[data_id] = low_data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = low_data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = high_data[i].x / sum;
}
}
}
}
#endif
}
__global__ void attn_softmax_v2(float* vals,
float* attn_mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
float scale,
int iterations,
int reduceWidth)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float4 data[MAX_REG_SIZE];
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length);
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity);
data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? vals[data_id + 1]
: minus_infinity;
data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? vals[data_id + 2]
: minus_infinity;
data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? vals[data_id + 3]
: minus_infinity;
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
data[i].y += attn_mask[data_id + mask_offset + 1];
data[i].z += attn_mask[data_id + mask_offset + 2];
data[i].w += attn_mask[data_id + mask_offset + 3];
}
} else {
data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride && (data_id + 1) < sequence_length)
? (vals[data_id + 1])
: minus_infinity;
data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride && (data_id + 2) < sequence_length)
? (vals[data_id + 2])
: minus_infinity;
data[i].w = minus_infinity;
if (attn_mask && recompute) {
data[i].x += attn_mask[data_id + mask_offset];
if ((data_id + 1) < sequence_length)
data[i].y += attn_mask[data_id + mask_offset + 1];
if ((data_id + 2) < sequence_length)
data[i].z += attn_mask[data_id + mask_offset + 2];
}
}
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = data[i].x / sum;
vals[data_id + 1] = data[i].y / sum;
vals[data_id + 2] = data[i].z / sum;
vals[data_id + 3] = data[i].w / sum;
} else {
vals[data_id] = data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum;
}
}
}
}
}
template <typename T>
void launch_attn_softmax_v2(T* vals,
T* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
hipStream_t stream)
{
int total_count = batch_size * heads * num_seq;
dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1);
dim3 block_dim(ATTN_THREADS);
const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE;
const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;
if (sequence_length <= 32768)
hipLaunchKernelGGL(( attn_softmax_v2), dim3(grid_dim), dim3(block_dim), 0, stream,
vals,
mask,
triangular,
recompute,
local_attention,
window_size,
total_count,
(triangular ? (heads * batch_size) : heads),
sequence_length,
num_seq,
scale,
iterations,
reduce_width);
else
throw std::runtime_error("Unsupport Seq_Length!");
}
template void launch_attn_softmax_v2(float* vals,
float* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
hipStream_t stream);
template void launch_attn_softmax_v2(__half* vals,
__half* mask,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
float scale,
hipStream_t stream);
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#define WARP_SIZE 32
#define CUDA_CHECK(callstr) \
{ \
cudaError_t error_code = callstr; \
if (error_code != cudaSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
return std::max(
std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
// Use at least 1 block, since CUDA does not allow empty block
1);
}
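// Example: for N = 100000 elements DS_GET_BLOCKS returns
// min((100000 + 511) / 512, 262144) = 196 blocks, and it never returns fewer than
// one block so that empty inputs still get a valid launch configuration.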
class Context {
public:
Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0)
{
curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(_gen, 123);
if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
auto message = std::string("Fail to create cublas handle.");
std::cerr << message << std::endl;
throw std::runtime_error(message);
}
cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
cudaEventCreate(&_comp1_event, (cudaEventDisableTiming | cudaEventBlockingSync));
cudaEventCreate(&_comp2_event, (cudaEventDisableTiming | cudaEventBlockingSync));
cudaEventCreate(&_comp_event, (cudaEventDisableTiming | cudaEventBlockingSync));
cudaEventCreate(&_comm_event, (cudaEventDisableTiming | cudaEventBlockingSync));
}
virtual ~Context()
{
cublasDestroy(_cublasHandle);
cudaFree(_workspace);
cudaEventDestroy(_comp1_event);
cudaEventDestroy(_comp2_event);
cudaEventDestroy(_comp_event);
cudaEventDestroy(_comm_event);
}
static Context& Instance()
{
static Context _ctx;
return _ctx;
}
void GenWorkSpace(size_t size)
{
if (!_workspace) {
assert(_workspace == nullptr);
cudaMalloc(&_workspace, size);
} else if (_workSpaceSize < size) {
cudaFree(_workspace);
cudaMalloc(&_workspace, size);
}
_workSpaceSize = size;
}
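// The workspace only grows: a reallocation happens when the requested size exceeds
// the current one, otherwise the existing buffer is reused. Note that
// _workSpaceSize is overwritten even when no reallocation occurs, so the recorded
// size can understate the actual allocation.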
cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }
size_t get_workspace_size() const { return _workSpaceSize; }
void* GetWorkSpace() { return _workspace; }
inline unsigned new_token(unsigned layer_id)
{
if (layer_id == 0) _token_length++;
return _token_length;
}
inline void reset_tokens(unsigned initial_tokens = 0)
{
_num_tokens = initial_tokens;
} //_token_length = 0; }
inline unsigned current_tokens() const { return _num_tokens; }
inline void advance_tokens() { _num_tokens++; }
curandGenerator_t& GetRandGenerator() { return _gen; }
cudaStream_t GetCommStream(bool async_op = false)
{
if (!_comm_stream)
_comm_stream = async_op ? at::cuda::getStreamFromPool(true)
: at::cuda::getCurrentCUDAStream();
return _comm_stream;
}
cudaStream_t GetCurrentStream(bool other_stream = false)
{
// get current pytorch stream.
if (other_stream) {
if (!_stream) _stream = at::cuda::getStreamFromPool(true);
return _stream;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
return stream;
}
cublasHandle_t GetCublasHandle() { return _cublasHandle; }
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
{
uint64_t offset = _curr_offset;
_curr_offset += offset_inc;
return std::pair<uint64_t, uint64_t>(_seed, offset);
}
void SetSeed(uint64_t new_seed) { _seed = new_seed; }
const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }
inline void SynchComp()
{
cudaEventRecord(_comp_event, _comp_stream);
cudaStreamWaitEvent(_comm_stream, _comp_event, 0);
}
inline void SynchComm()
{
cudaEventRecord(_comm_event, _comm_stream);
cudaStreamWaitEvent(_comp_stream, _comm_event, 0);
}
private:
curandGenerator_t _gen;
cublasHandle_t _cublasHandle;
cudaEvent_t _comp_event;
cudaEvent_t _comm_event;
void* _workspace;
uint64_t _seed;
uint64_t _curr_offset;
size_t _workSpaceSize;
cudaEvent_t _comp1_event;
cudaEvent_t _comp2_event;
cudaStream_t _stream;
unsigned _token_length;
unsigned _num_tokens;
std::vector<std::array<int, 3>> _gemm_algos;
cudaStream_t _comp_stream;
cudaStream_t _comm_stream;
std::unordered_map<int, int> _world_sizes;
};
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <ATen/hip/HIPContext.h>
#include <hip/hip_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "rocblas.h"
#include "hip/hip_runtime.h"
#include "hiprand/hiprand.h"
#define WARP_SIZE 32
#define CUDA_CHECK(callstr) \
{ \
hipError_t error_code = callstr; \
if (error_code != hipSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
return std::max(
std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
// Use at least 1 block, since CUDA does not allow empty block
1);
}
class Context {
public:
Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _stream(0)
{
hiprandCreateGenerator(&_gen, HIPRAND_RNG_PSEUDO_DEFAULT);
hiprandSetPseudoRandomGeneratorSeed(_gen, 123);
if (rocblas_create_handle(&_cublasHandle) != rocblas_status_success) {
auto message = std::string("Fail to create cublas handle.");
std::cerr << message << std::endl;
throw std::runtime_error(message);
}
rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
hipEventCreate(&_comp1_event, (hipEventDisableTiming | hipEventBlockingSync));
hipEventCreate(&_comp2_event, (hipEventDisableTiming | hipEventBlockingSync));
hipEventCreate(&_comp_event, (hipEventDisableTiming | hipEventBlockingSync));
hipEventCreate(&_comm_event, (hipEventDisableTiming | hipEventBlockingSync));
}
virtual ~Context()
{
rocblas_destroy_handle(_cublasHandle);
hipFree(_workspace);
hipEventDestroy(_comp1_event);
hipEventDestroy(_comp2_event);
hipEventDestroy(_comp_event);
hipEventDestroy(_comm_event);
}
static Context& Instance()
{
static Context _ctx;
return _ctx;
}
void GenWorkSpace(size_t size)
{
if (!_workspace) {
assert(_workspace == nullptr);
hipMalloc(&_workspace, size);
} else if (_workSpaceSize < size) {
hipFree(_workspace);
hipMalloc(&_workspace, size);
}
_workSpaceSize = size;
}
hipEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }
size_t get_workspace_size() const { return _workSpaceSize; }
void* GetWorkSpace() { return _workspace; }
inline unsigned new_token(unsigned layer_id)
{
if (layer_id == 0) _token_length++;
return _token_length;
}
inline void reset_tokens(unsigned initial_tokens = 0)
{
_num_tokens = initial_tokens;
} //_token_length = 0; }
inline unsigned current_tokens() const { return _num_tokens; }
inline void advance_tokens() { _num_tokens++; }
hiprandGenerator_t& GetRandGenerator() { return _gen; }
hipStream_t GetCommStream(bool async_op = false)
{
if (!_comm_stream)
_comm_stream = async_op ? at::hip::getStreamFromPoolMasqueradingAsCUDA(true)
: at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
return _comm_stream;
}
hipStream_t GetCurrentStream(bool other_stream = false)
{
// get current pytorch stream.
if (other_stream) {
if (!_stream) _stream = at::hip::getStreamFromPoolMasqueradingAsCUDA(true);
return _stream;
}
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
return stream;
}
rocblas_handle GetCublasHandle() { return _cublasHandle; }
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
{
uint64_t offset = _curr_offset;
_curr_offset += offset_inc;
return std::pair<uint64_t, uint64_t>(_seed, offset);
}
void SetSeed(uint64_t new_seed) { _seed = new_seed; }
const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }
inline void SynchComp()
{
hipEventRecord(_comp_event, _comp_stream);
hipStreamWaitEvent(_comm_stream, _comp_event, 0);
}
inline void SynchComm()
{
hipEventRecord(_comm_event, _comm_stream);
hipStreamWaitEvent(_comp_stream, _comm_event, 0);
}
private:
hiprandGenerator_t _gen;
rocblas_handle _cublasHandle;
hipEvent_t _comp_event;
hipEvent_t _comm_event;
void* _workspace;
uint64_t _seed;
uint64_t _curr_offset;
size_t _workSpaceSize;
hipEvent_t _comp1_event;
hipEvent_t _comp2_event;
hipStream_t _stream;
unsigned _token_length;
unsigned _num_tokens;
std::vector<std::array<int, 3>> _gemm_algos;
hipStream_t _comp_stream;
hipStream_t _comm_stream;
std::unordered_map<int, int> _world_sizes;
};
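// Illustrative sketch (editorial addition, not part of the original header).
// Typical call pattern for the Context singleton inside a layer's forward
// pass: reserve scratch space, pick up the current stream and rocBLAS handle,
// and advance the per-call RNG offset used by the dropout kernels. The byte
// count, offset increment, and helper name below are placeholder values.
inline void example_context_usage()
{
    Context& ctx = Context::Instance();
    // Grow (or reuse) the shared scratch buffer.
    ctx.GenWorkSpace(16 * 1024 * 1024);
    void* scratch = ctx.GetWorkSpace();
    (void)scratch;
    // Bind the rocBLAS handle to the current stream before issuing GEMMs.
    hipStream_t stream = ctx.GetCurrentStream();
    rocblas_set_stream(ctx.GetCublasHandle(), stream);
    // Each dropout launch consumes a chunk of the (seed, offset) space.
    std::pair<uint64_t, uint64_t> seed_and_offset = ctx.IncrementOffset(16);
    (void)seed_and_offset;
}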
#pragma once
#include <assert.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>
#include <stdio.h>
int cublas_gemm_ex(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
cublasGemmAlgo_t algo)
{
cublasStatus_t status = cublasGemmEx(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
CUDA_R_32F,
(transa == CUBLAS_OP_N) ? m : k,
(const void*)B,
CUDA_R_32F,
(transb == CUBLAS_OP_N) ? k : n,
(const void*)beta,
C,
CUDA_R_32F,
m,
CUDA_R_32F,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_gemm_ex(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
cublasGemmAlgo_t algo)
{
cublasStatus_t status = cublasGemmEx(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
CUDA_R_16F,
(transa == CUBLAS_OP_N) ? m : k,
(const void*)B,
CUDA_R_16F,
(transb == CUBLAS_OP_N) ? k : n,
(const void*)beta,
(void*)C,
CUDA_R_16F,
m,
CUDA_R_32F,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
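// Illustrative sketch (editorial addition, not part of the original file).
// cuBLAS is column-major, so a row-major product C = A * B (A: rows x k,
// B: k x cols, C: rows x cols) is usually expressed as C^T = B^T * A^T by
// swapping the operands, as below. The helper name is hypothetical.
inline int example_gemm_fp32(cublasHandle_t handle,
                             const float* A, // device ptr, row-major rows x k
                             const float* B, // device ptr, row-major k x cols
                             float* C,       // device ptr, row-major rows x cols
                             int rows,
                             int cols,
                             int k)
{
    const float alpha = 1.0f;
    const float beta = 0.0f;
    // In the column-major view, B acts as the left operand with leading
    // dimension cols, and A as the right operand with leading dimension k.
    return cublas_gemm_ex(handle,
                          CUBLAS_OP_N,
                          CUBLAS_OP_N,
                          cols, // m: rows of C^T
                          rows, // n: cols of C^T
                          k,
                          &alpha,
                          &beta,
                          B,
                          A,
                          C,
                          CUBLAS_GEMM_DEFAULT);
}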
int cublas_strided_batched_gemm(cublasHandle_t handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
cublasOperation_t op_A,
cublasOperation_t op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
CUDA_R_32F,
(op_A == CUBLAS_OP_N) ? m : k,
stride_A,
B,
CUDA_R_32F,
(op_B == CUBLAS_OP_N) ? k : n,
stride_B,
beta,
C,
CUDA_R_32F,
m,
stride_C,
batch,
CUDA_R_32F,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
int cublas_strided_batched_gemm(cublasHandle_t handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
cublasOperation_t op_A,
cublasOperation_t op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
{
cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
(op_A == CUBLAS_OP_N) ? m : k,
stride_A,
B,
CUDA_R_16F,
(op_B == CUBLAS_OP_N) ? k : n,
stride_B,
beta,
C,
CUDA_R_16F,
m,
stride_C,
batch,
CUDA_R_32F,
algo);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
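// Illustrative sketch (editorial addition, not part of the original file).
// Batched fp16 multiply of `batch` independent matrices with fp32
// accumulation, one pattern such wrappers are commonly used for (for example
// per-head attention GEMMs). Each A_i is stored column-major as k x m and
// consumed transposed; strides are element offsets between consecutive
// matrices. The helper name is hypothetical.
inline int example_batched_gemm_fp16(cublasHandle_t handle,
                                     const __half* A, // batch x (k x m), column-major
                                     const __half* B, // batch x (k x n), column-major
                                     __half* C,       // batch x (m x n), column-major
                                     int m,
                                     int n,
                                     int k,
                                     int batch)
{
    const float alpha = 1.0f;
    const float beta = 0.0f;
    return cublas_strided_batched_gemm(handle,
                                       m,
                                       n,
                                       k,
                                       &alpha,
                                       &beta,
                                       A,
                                       B,
                                       C,
                                       CUBLAS_OP_T, // consume A_i as m x k
                                       CUBLAS_OP_N,
                                       m * k, // stride between consecutive A_i
                                       k * n, // stride between consecutive B_i
                                       m * n, // stride between consecutive C_i
                                       batch,
                                       CUBLAS_GEMM_DEFAULT);
}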