Commit c25a91b6 authored by aiss's avatar aiss
Browse files

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include "custom_cuda_layers.h" #include "custom_cuda_layers.h"
#include "memory_access_utils.h" #include "memory_access_utils.h"
......
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include <torch/extension.h> #include <torch/extension.h>
#include <vector> #include <vector>
......
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include "custom_cuda_layers.h" #include "custom_cuda_layers.h"
#include "memory_access_utils.h" #include "memory_access_utils.h"
......
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include <cassert> #include <cassert>
#include "custom_cuda_layers.h" #include "custom_cuda_layers.h"
......
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a // Copyright (c) Microsoft Corporation.
// https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp // SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*
DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
https:github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
*/
#include <torch/extension.h> #include <torch/extension.h>
#include <string> #include <string>
......
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include <cassert> #include <cassert>
#include "memory_access_utils.h" #include "memory_access_utils.h"
......
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include <c10/cuda/CUDAStream.h> #include <c10/cuda/CUDAStream.h>
#include <torch/extension.h> #include <torch/extension.h>
......
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#pragma once #pragma once
...@@ -8,11 +9,11 @@ Copyright 2022 The Microsoft DeepSpeed Team ...@@ -8,11 +9,11 @@ Copyright 2022 The Microsoft DeepSpeed Team
#define HALF_PRECISION_AVAILABLE = 1 #define HALF_PRECISION_AVAILABLE = 1
#endif #endif
#ifdef __HIPCC__ #ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_cooperative_groups.h> #include <hip/hip_cooperative_groups.h>
#else #else
#include <cooperative_groups.h> #include <cooperative_groups.h>
#endif #endif
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include "cublas_wrappers.h" #include "cublas_wrappers.h"
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include "custom_cuda_layers.h" #include "custom_cuda_layers.h"
...@@ -277,7 +278,7 @@ void launch_dropout(T* out, ...@@ -277,7 +278,7 @@ void launch_dropout(T* out,
grid_dim.x <<= 1; grid_dim.x <<= 1;
} }
uint64_t inc = total_count / grid_dim.x / block_dim.x; uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc); std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
if (bwd) if (bwd)
dropout_kernel_bwd<<<grid_dim, block_dim, 0, stream>>>( dropout_kernel_bwd<<<grid_dim, block_dim, 0, stream>>>(
total_count, ratio, vals, out, mask, seed); total_count, ratio, vals, out, mask, seed);
...@@ -624,7 +625,7 @@ void launch_dropout(T* out, ...@@ -624,7 +625,7 @@ void launch_dropout(T* out,
dim3 block_dim = DS_CUDA_NUM_THREADS; dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc); std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
dropout_kernel<<<grid_dim, block_dim, 0, stream>>>( dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
total_count, dim, ratio, bias, out, mask, seed); total_count, dim, ratio, bias, out, mask, seed);
...@@ -846,7 +847,7 @@ void launch_dropout(T* out, ...@@ -846,7 +847,7 @@ void launch_dropout(T* out,
dim3 block_dim = DS_CUDA_NUM_THREADS; dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc); std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
dropout_kernel<<<grid_dim, block_dim, 0, stream>>>( dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
total_count, dim, ratio, input, residual, bias, out, mask, seed); total_count, dim, ratio, input, residual, bias, out, mask, seed);
......
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <torch/extension.h> #include <torch/extension.h>
#include <cublas_v2.h> #include <cublas_v2.h>
...@@ -73,8 +78,8 @@ BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id, ...@@ -73,8 +78,8 @@ BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
_normalize_invertible(normalize_invertible), _normalize_invertible(normalize_invertible),
_gelu_checkpoint(gelu_checkpoint), _gelu_checkpoint(gelu_checkpoint),
_stochastic_mode(stochastic_mode), _stochastic_mode(stochastic_mode),
_stream(Context::Instance().GetCurrentStream()), _stream(TrainingContext::Instance().GetCurrentStream()),
_cublasHandle(Context::Instance().GetCublasHandle()), _cublasHandle(TrainingContext::Instance().GetCublasHandle()),
_qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length, _qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
3 * hidden_size, 3 * hidden_size,
hidden_size, hidden_size,
...@@ -179,7 +184,7 @@ void BertTransformerLayer<T>::Forward(unsigned bsz, ...@@ -179,7 +184,7 @@ void BertTransformerLayer<T>::Forward(unsigned bsz,
if (!_stochastic_mode) cudaStreamSynchronize(_stream); if (!_stochastic_mode) cudaStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace()); T* workspace = static_cast<T*>(TrainingContext::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size; size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace; T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size; T* buf_1 = buf_0 + small_buf_size;
...@@ -339,7 +344,7 @@ void BertTransformerLayer<T>::Backward(unsigned bsz, ...@@ -339,7 +344,7 @@ void BertTransformerLayer<T>::Backward(unsigned bsz,
if (!_stochastic_mode) cudaStreamSynchronize(_stream); if (!_stochastic_mode) cudaStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace()); T* workspace = static_cast<T*>(TrainingContext::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size; size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace; T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size; T* buf_1 = buf_0 + small_buf_size;
...@@ -605,25 +610,26 @@ int create_transformer_layer(unsigned layer_id, ...@@ -605,25 +610,26 @@ int create_transformer_layer(unsigned layer_id,
bool gelu_checkpoint, bool gelu_checkpoint,
bool stochastic_mode) bool stochastic_mode)
{ {
Context::Instance().SetSeed(seed); TrainingContext::Instance().SetSeed(seed);
Context::Instance().TestGemmFP16( TrainingContext::Instance().TestGemmFP16(
test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads); test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);
auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id, auto layer =
batch_size, std::make_shared<BertTransformerLayer<T>>(layer_id,
hidden_dim, batch_size,
num_heads, hidden_dim,
intermediate_size, num_heads,
init_seq_length, intermediate_size,
attn_dropout_ratio, init_seq_length,
hidden_dropout_ratio, attn_dropout_ratio,
layer_norm_eps, hidden_dropout_ratio,
pre_or_postLayerNorm, layer_norm_eps,
Context::Instance().GetGemmAlgos(), pre_or_postLayerNorm,
attn_dropout_checkpoint, TrainingContext::Instance().GetGemmAlgos(),
normalize_invertible, attn_dropout_checkpoint,
gelu_checkpoint, normalize_invertible,
stochastic_mode); gelu_checkpoint,
stochastic_mode);
s_transformer_layers[layer_id] = layer; s_transformer_layers[layer_id] = layer;
...@@ -721,7 +727,7 @@ std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id, ...@@ -721,7 +727,7 @@ std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
layer->IsTrainingMode(), layer->IsTrainingMode(),
layer->GeluCheckpoint())}, layer->GeluCheckpoint())},
options); options);
Context::Instance().SetWorkSpace((T*)workspace.data_ptr()); TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr());
auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output); auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input)); auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
...@@ -905,7 +911,7 @@ std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id, ...@@ -905,7 +911,7 @@ std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
layer->IsTrainingMode(), layer->IsTrainingMode(),
layer->GeluCheckpoint())}, layer->GeluCheckpoint())},
options); options);
Context::Instance().SetWorkSpace((T*)workspace.data_ptr()); TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr());
auto grad_input = torch::empty_like(input); auto grad_input = torch::empty_like(input);
auto grad_attn_qkvw = torch::empty_like(attn_qkvw); auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include "custom_cuda_layers.h" #include "custom_cuda_layers.h"
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include "general_kernels.h" #include "general_kernels.h"
......
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include "conversion_utils.h"
#include "inference_cuda_layers.h" #include "inference_cuda_layers.h"
#ifndef __HIP_PLATFORM_HCC__ #ifndef __HIP_PLATFORM_HCC__
...@@ -11,8 +13,9 @@ Copyright 2022 The Microsoft DeepSpeed Team ...@@ -11,8 +13,9 @@ Copyright 2022 The Microsoft DeepSpeed Team
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
__global__ void apply_rotary_pos_emb(float* mixed_query, template <typename T>
float* key_layer, __global__ void apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned rotary_dim, unsigned rotary_dim,
unsigned seq_len, unsigned seq_len,
unsigned seq_offset, unsigned seq_offset,
...@@ -39,8 +42,8 @@ __global__ void apply_rotary_pos_emb(float* mixed_query, ...@@ -39,8 +42,8 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
while (lane < rotary_dim) { while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane]; float q = conversion::to<float>(mixed_query[offset + lane]);
float k = key_layer[k_offset + lane]; float k = conversion::to<float>(key_layer[k_offset + lane]);
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0); float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign); float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign); float k_rot = (k * rotary_sign);
...@@ -49,59 +52,14 @@ __global__ void apply_rotary_pos_emb(float* mixed_query, ...@@ -49,59 +52,14 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq); q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q; mixed_query[offset + lane] = conversion::to<T>(q);
key_layer[k_offset + lane] = k; key_layer[k_offset + lane] = conversion::to<T>(k);
lane += WARP_SIZE; lane += WARP_SIZE;
} }
} }
} }
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[k_offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb1(float* mixed_query, __global__ void apply_rotary_pos_emb1(float* mixed_query,
float* key_layer, float* key_layer,
unsigned rotary_dim, unsigned rotary_dim,
...@@ -147,8 +105,10 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query, ...@@ -147,8 +105,10 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query,
} }
} }
} }
__global__ void apply_rotary_pos_emb1(__half* mixed_query,
__half* key_layer, template <typename T>
__global__ void apply_rotary_pos_emb1(T* mixed_query,
T* key_layer,
unsigned rotary_dim, unsigned rotary_dim,
unsigned seq_len, unsigned seq_len,
unsigned seq_offset, unsigned seq_offset,
...@@ -184,8 +144,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query, ...@@ -184,8 +144,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
while (lane < rotary_dim) { while (lane < rotary_dim) {
float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim; float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane]; float q = conversion::to<float>(mixed_query[offset + lane]);
float k = (float)key_layer[k_offset + lane]; float k = conversion::to<float>(key_layer[k_offset + lane]);
float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0); float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
float q_rot = (q * rotary_sign); float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign); float k_rot = (k * rotary_sign);
...@@ -196,8 +156,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query, ...@@ -196,8 +156,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq); q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq); k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q; mixed_query[offset + lane] = conversion::to<T>(q);
key_layer[k_offset + lane] = (__half)k; key_layer[k_offset + lane] = conversion::to<T>(k);
lane += WARP_SIZE; lane += WARP_SIZE;
} }
...@@ -255,6 +215,20 @@ template void launch_apply_rotary_pos_emb<float>(float*, ...@@ -255,6 +215,20 @@ template void launch_apply_rotary_pos_emb<float>(float*,
bool, bool,
cudaStream_t, cudaStream_t,
int); int);
#ifdef BF16_AVAILABLE
template void launch_apply_rotary_pos_emb<__nv_bfloat16>(__nv_bfloat16*,
__nv_bfloat16*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
cudaStream_t,
int);
#endif
template void launch_apply_rotary_pos_emb<__half>(__half*, template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*, __half*,
unsigned, unsigned,
...@@ -268,6 +242,59 @@ template void launch_apply_rotary_pos_emb<__half>(__half*, ...@@ -268,6 +242,59 @@ template void launch_apply_rotary_pos_emb<__half>(__half*,
cudaStream_t, cudaStream_t,
int); int);
template __global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);
#ifdef BF16_AVAILABLE
template __global__ void apply_rotary_pos_emb(__nv_bfloat16* mixed_query,
__nv_bfloat16* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);
#endif
template __global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);
#ifdef BF16_AVAILABLE
template __global__ void apply_rotary_pos_emb1(__nv_bfloat16* mixed_query,
__nv_bfloat16* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);
#endif
template __global__ void apply_rotary_pos_emb1(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);
/* /*
__global__ void apply_rotary_pos_emb(float* mixed_query, __global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer, float* key_layer,
......
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include "conversion_utils.h"
#include "inference_cuda_layers.h" #include "inference_cuda_layers.h"
#define MAX_QUANTIZE_GROUPING 1024 #define MAX_QUANTIZE_GROUPING 1024
...@@ -9,7 +11,8 @@ Copyright 2022 The Microsoft DeepSpeed Team ...@@ -9,7 +11,8 @@ Copyright 2022 The Microsoft DeepSpeed Team
#define loop_unroll 1 #define loop_unroll 1
#define loop_unroll_bits 1 #define loop_unroll_bits 1
__global__ void dequantize_kernel(float* output, template <typename T>
__global__ void dequantize_kernel(T* output,
const int8_t* input, const int8_t* input,
const float* qscale, const float* qscale,
int output_size, int output_size,
...@@ -37,40 +40,7 @@ __global__ void dequantize_kernel(float* output, ...@@ -37,40 +40,7 @@ __global__ void dequantize_kernel(float* output,
float scale_data = qscale[scale_index]; float scale_data = qscale[scale_index];
output[q_index] = (scale_data * (float)q); output[q_index] = conversion::to<T>(scale_data * (float)q);
tid += blockDim.x;
}
}
__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count)
{
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = __float2half(scale_data * (float)q);
tid += blockDim.x; tid += blockDim.x;
} }
} }
...@@ -101,6 +71,18 @@ template void launch_dequantize<float>(float*, ...@@ -101,6 +71,18 @@ template void launch_dequantize<float>(float*,
unsigned, unsigned,
unsigned, unsigned,
cudaStream_t); cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_dequantize<__nv_bfloat16>(__nv_bfloat16*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
#endif
template void launch_dequantize<__half>(__half*, template void launch_dequantize<__half>(__half*,
const int8_t*, const int8_t*,
const float*, const float*,
...@@ -119,7 +101,8 @@ __global__ void dequantize_kernel(float* output, ...@@ -119,7 +101,8 @@ __global__ void dequantize_kernel(float* output,
{ {
} }
__global__ void dequantize_kernel(__half* output, template <typename T>
__global__ void dequantize_kernel(T* output,
const int8_t* input, const int8_t* input,
const float* qscale, const float* qscale,
unsigned hidden_dim, unsigned hidden_dim,
...@@ -143,12 +126,12 @@ __global__ void dequantize_kernel(__half* output, ...@@ -143,12 +126,12 @@ __global__ void dequantize_kernel(__half* output,
int8_t* q_int8 = (int8_t*)&q; int8_t* q_int8 = (int8_t*)&q;
float2 q_f; float2 q_f;
__half* q_h = (__half*)&q_f; T* q_h = (T*)&q_f;
q_h[0] = __float2half(local_scale * (float)q_int8[0]); q_h[0] = conversion::to<T>(local_scale * (float)q_int8[0]);
q_h[1] = __float2half(local_scale * (float)q_int8[1]); q_h[1] = conversion::to<T>(local_scale * (float)q_int8[1]);
q_h[2] = __float2half(local_scale * (float)q_int8[2]); q_h[2] = conversion::to<T>(local_scale * (float)q_int8[2]);
q_h[3] = __float2half(local_scale * (float)q_int8[3]); q_h[3] = conversion::to<T>(local_scale * (float)q_int8[3]);
output_cast[tid] = q_f; output_cast[tid] = q_f;
tid += blockDim.x; tid += blockDim.x;
} }
...@@ -185,6 +168,17 @@ template void launch_dequantize<float>(float*, ...@@ -185,6 +168,17 @@ template void launch_dequantize<float>(float*,
unsigned, unsigned,
unsigned, unsigned,
cudaStream_t); cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_dequantize<__nv_bfloat16>(__nv_bfloat16*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
cudaStream_t);
#endif
template void launch_dequantize<__half>(__half*, template void launch_dequantize<__half>(__half*,
const int8_t*, const int8_t*,
const float*, const float*,
...@@ -192,3 +186,45 @@ template void launch_dequantize<__half>(__half*, ...@@ -192,3 +186,45 @@ template void launch_dequantize<__half>(__half*,
unsigned, unsigned,
unsigned, unsigned,
cudaStream_t); cudaStream_t);
template __global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count);
#ifdef BF16_AVAILABLE
template __global__ void dequantize_kernel(__nv_bfloat16* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count);
#endif
template __global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count);
#ifdef BF16_AVAILABLE
template __global__ void dequantize_kernel(__nv_bfloat16* output,
const int8_t* input,
const float* qscale,
unsigned hidden_dim,
unsigned merge_hidden,
int cnt);
#endif
template __global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned hidden_dim,
unsigned merge_hidden,
int cnt);
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include "conversion_utils.h" #include "conversion_utils.h"
#include "inference_cuda_layers.h" #include "inference_cuda_layers.h"
...@@ -10,6 +11,11 @@ namespace cg = cooperative_groups; ...@@ -10,6 +11,11 @@ namespace cg = cooperative_groups;
#define MAX_CAP 4 #define MAX_CAP 4
#define MAX_SEQ 2048 #define MAX_SEQ 2048
// only used to avoid compilation error due to lack of definition.
#ifndef BF16_AVAILABLE
using __nv_bfloat162 = __half2;
#endif
inline __device__ float gelu(const float x) inline __device__ float gelu(const float x)
{ {
const float sqrt_param = 0.79788456080286535587989211986876f; const float sqrt_param = 0.79788456080286535587989211986876f;
...@@ -65,6 +71,13 @@ void launch_bias_gelu(T* input, ...@@ -65,6 +71,13 @@ void launch_bias_gelu(T* input,
} }
template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t); template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_bias_gelu<__nv_bfloat16>(__nv_bfloat16*,
const __nv_bfloat16*,
int,
int,
cudaStream_t);
#endif
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t); template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t);
/* /*
...@@ -115,6 +128,13 @@ void launch_bias_add(T* input, ...@@ -115,6 +128,13 @@ void launch_bias_add(T* input,
} }
template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t); template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_bias_add<__nv_bfloat16>(__nv_bfloat16*,
const __nv_bfloat16*,
int,
int,
cudaStream_t);
#endif
template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t); template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t);
__global__ void fused_bias_residual(float* residual, __global__ void fused_bias_residual(float* residual,
...@@ -162,16 +182,19 @@ __global__ void fused_bias_residual(float* residual, ...@@ -162,16 +182,19 @@ __global__ void fused_bias_residual(float* residual,
} }
} }
__global__ void fused_bias_residual(__half* residual, template <typename T>
const __half* hidden_state, __global__ void fused_bias_residual(T* residual,
const __half* attn, const T* hidden_state,
const __half* bias, const T* attn,
const __half* attn_bias, const T* bias,
const T* attn_bias,
const int total_count, const int total_count,
const int intermediate_size, const int intermediate_size,
const float mp_scale, const float mp_scale,
const bool preln) const bool preln)
{ {
using T2 =
typename std::conditional<std::is_same<T, __half>::value, __half2, __nv_bfloat162>::type;
float2* res_fl2_ptr = reinterpret_cast<float2*>(residual); float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state); const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn); const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
...@@ -186,26 +209,26 @@ __global__ void fused_bias_residual(__half* residual, ...@@ -186,26 +209,26 @@ __global__ void fused_bias_residual(__half* residual,
const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size]; const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size]; const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];
__half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2); T2* res_half2 = reinterpret_cast<T2*>(&res_fl2);
const __half2* hs_half2 = reinterpret_cast<const __half2*>(&hs_fl2); const T2* hs_half2 = reinterpret_cast<const T2*>(&hs_fl2);
const __half2* attn_half2 = reinterpret_cast<const __half2*>(&attn_fl2); const T2* attn_half2 = reinterpret_cast<const T2*>(&attn_fl2);
const __half2* bias_half2 = reinterpret_cast<const __half2*>(&bias_fl2); const T2* bias_half2 = reinterpret_cast<const T2*>(&bias_fl2);
const __half2* attn_bias_half2 = reinterpret_cast<const __half2*>(&attn_bias_fl2); const T2* attn_bias_half2 = reinterpret_cast<const T2*>(&attn_bias_fl2);
float2 res_low = __half22float2(res_half2[0]); float2 res_low = conversion::to<float2>(res_half2[0]);
float2 res_high = __half22float2(res_half2[1]); float2 res_high = conversion::to<float2>(res_half2[1]);
const float2 hs_low = __half22float2(hs_half2[0]); const float2 hs_low = conversion::to<float2>(hs_half2[0]);
const float2 hs_high = __half22float2(hs_half2[1]); const float2 hs_high = conversion::to<float2>(hs_half2[1]);
const float2 attn_low = __half22float2(attn_half2[0]); const float2 attn_low = conversion::to<float2>(attn_half2[0]);
const float2 attn_high = __half22float2(attn_half2[1]); const float2 attn_high = conversion::to<float2>(attn_half2[1]);
const float2 bias_low = __half22float2(bias_half2[0]); const float2 bias_low = conversion::to<float2>(bias_half2[0]);
const float2 bias_high = __half22float2(bias_half2[1]); const float2 bias_high = conversion::to<float2>(bias_half2[1]);
const float2 attn_bias_low = __half22float2(attn_bias_half2[0]); const float2 attn_bias_low = conversion::to<float2>(attn_bias_half2[0]);
const float2 attn_bias_high = __half22float2(attn_bias_half2[1]); const float2 attn_bias_high = conversion::to<float2>(attn_bias_half2[1]);
if (preln) { if (preln) {
// residual = (residual + attention + bias + attention_bias) * // residual = (residual + attention + bias + attention_bias) *
...@@ -225,8 +248,8 @@ __global__ void fused_bias_residual(__half* residual, ...@@ -225,8 +248,8 @@ __global__ void fused_bias_residual(__half* residual,
res_high.x = (res_high.x + hs_high.x + bias_high.x); res_high.x = (res_high.x + hs_high.x + bias_high.x);
res_high.y = (res_high.y + hs_high.y + bias_high.y); res_high.y = (res_high.y + hs_high.y + bias_high.y);
} }
res_half2[0] = __float22half2_rn(res_low); res_half2[0] = conversion::to<T2>(res_low);
res_half2[1] = __float22half2_rn(res_high); res_half2[1] = conversion::to<T2>(res_high);
res_fl2_ptr[offset] = res_fl2; res_fl2_ptr[offset] = res_fl2;
} }
...@@ -261,9 +284,43 @@ void launch_bias_residual(T* residual, ...@@ -261,9 +284,43 @@ void launch_bias_residual(T* residual,
template void launch_bias_residual< template void launch_bias_residual<
float>(float*, float*, float*, float*, float*, int, int, int, bool, cudaStream_t); float>(float*, float*, float*, float*, float*, int, int, int, bool, cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_bias_residual<__nv_bfloat16>(__nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
int,
int,
int,
bool,
cudaStream_t);
#endif
template void launch_bias_residual< template void launch_bias_residual<
__half>(__half*, __half*, __half*, __half*, __half*, int, int, int, bool, cudaStream_t); __half>(__half*, __half*, __half*, __half*, __half*, int, int, int, bool, cudaStream_t);
#ifdef BF16_AVAILABLE
template __global__ void fused_bias_residual(__nv_bfloat16* residual,
const __nv_bfloat16* hidden_state,
const __nv_bfloat16* attn,
const __nv_bfloat16* bias,
const __nv_bfloat16* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale,
const bool preln);
#endif
template __global__ void fused_bias_residual(__half* residual,
const __half* hidden_state,
const __half* attn,
const __half* bias,
const __half* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale,
const bool preln);
__global__ void gptj_residual_add(float* residual, __global__ void gptj_residual_add(float* residual,
const float* hidden_state, const float* hidden_state,
const float* attn, const float* attn,
...@@ -304,15 +361,18 @@ __global__ void gptj_residual_add(float* residual, ...@@ -304,15 +361,18 @@ __global__ void gptj_residual_add(float* residual,
} }
} }
__global__ void gptj_residual_add(__half* residual, template <typename T>
const __half* hidden_state, __global__ void gptj_residual_add(T* residual,
const __half* attn, const T* hidden_state,
const __half* bias, const T* attn,
const __half* attn_bias, const T* bias,
const T* attn_bias,
const int total_count, const int total_count,
const int intermediate_size, const int intermediate_size,
const float mp_scale) const float mp_scale)
{ {
using T2 =
typename std::conditional<std::is_same<T, __half>::value, __half2, __nv_bfloat162>::type;
float2* res_fl2_ptr = reinterpret_cast<float2*>(residual); float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state); const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn); const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
...@@ -326,28 +386,28 @@ __global__ void gptj_residual_add(__half* residual, ...@@ -326,28 +386,28 @@ __global__ void gptj_residual_add(__half* residual,
const float2 attn_fl2 = attn_fl2_ptr[offset]; const float2 attn_fl2 = attn_fl2_ptr[offset];
const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size]; const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
__half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2); T2* res_half2 = reinterpret_cast<T2*>(&res_fl2);
const __half2* hs_half2 = reinterpret_cast<const __half2*>(&hs_fl2); const T2* hs_half2 = reinterpret_cast<const T2*>(&hs_fl2);
const __half2* attn_half2 = reinterpret_cast<const __half2*>(&attn_fl2); const T2* attn_half2 = reinterpret_cast<const T2*>(&attn_fl2);
const __half2* bias_half2 = reinterpret_cast<const __half2*>(&bias_fl2); const T2* bias_half2 = reinterpret_cast<const T2*>(&bias_fl2);
float2 res_low = __half22float2(res_half2[0]); float2 res_low = conversion::to<float2>(res_half2[0]);
float2 res_high = __half22float2(res_half2[1]); float2 res_high = conversion::to<float2>(res_half2[1]);
const float2 hs_low = __half22float2(hs_half2[0]); const float2 hs_low = conversion::to<float2>(hs_half2[0]);
const float2 hs_high = __half22float2(hs_half2[1]); const float2 hs_high = conversion::to<float2>(hs_half2[1]);
const float2 attn_low = __half22float2(attn_half2[0]); const float2 attn_low = conversion::to<float2>(attn_half2[0]);
const float2 attn_high = __half22float2(attn_half2[1]); const float2 attn_high = conversion::to<float2>(attn_half2[1]);
const float2 bias_low = __half22float2(bias_half2[0]); const float2 bias_low = conversion::to<float2>(bias_half2[0]);
const float2 bias_high = __half22float2(bias_half2[1]); const float2 bias_high = conversion::to<float2>(bias_half2[1]);
if (attn_bias) { if (attn_bias) {
const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size]; const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];
const __half2* attn_bias_half2 = reinterpret_cast<const __half2*>(&attn_bias_fl2); const T2* attn_bias_half2 = reinterpret_cast<const T2*>(&attn_bias_fl2);
const float2 attn_bias_low = __half22float2(attn_bias_half2[0]); const float2 attn_bias_low = conversion::to<float2>(attn_bias_half2[0]);
const float2 attn_bias_high = __half22float2(attn_bias_half2[1]); const float2 attn_bias_high = conversion::to<float2>(attn_bias_half2[1]);
// residual += attention_bias // residual += attention_bias
res_low.x += attn_bias_low.x; res_low.x += attn_bias_low.x;
res_low.y += attn_bias_low.y; res_low.y += attn_bias_low.y;
...@@ -360,8 +420,8 @@ __global__ void gptj_residual_add(__half* residual, ...@@ -360,8 +420,8 @@ __global__ void gptj_residual_add(__half* residual,
res_high.x = attn_high.x + hs_high.x + (res_high.x + bias_high.x) * mp_scale; res_high.x = attn_high.x + hs_high.x + (res_high.x + bias_high.x) * mp_scale;
res_high.y = attn_high.y + hs_high.y + (res_high.y + bias_high.y) * mp_scale; res_high.y = attn_high.y + hs_high.y + (res_high.y + bias_high.y) * mp_scale;
res_half2[0] = __float22half2_rn(res_low); res_half2[0] = conversion::to<T2>(res_low);
res_half2[1] = __float22half2_rn(res_high); res_half2[1] = conversion::to<T2>(res_high);
res_fl2_ptr[offset] = res_fl2; res_fl2_ptr[offset] = res_fl2;
} }
...@@ -395,6 +455,19 @@ template void launch_gptj_residual_add<float>(float*, ...@@ -395,6 +455,19 @@ template void launch_gptj_residual_add<float>(float*,
int, int,
int, int,
cudaStream_t); cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_gptj_residual_add<__nv_bfloat16>(__nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
int,
int,
int,
cudaStream_t);
#endif
template void launch_gptj_residual_add<__half>(__half*, template void launch_gptj_residual_add<__half>(__half*,
__half*, __half*,
__half*, __half*,
...@@ -404,6 +477,27 @@ template void launch_gptj_residual_add<__half>(__half*, ...@@ -404,6 +477,27 @@ template void launch_gptj_residual_add<__half>(__half*,
int, int,
int, int,
cudaStream_t); cudaStream_t);
#ifdef BF16_AVAILABLE
template __global__ void gptj_residual_add(__nv_bfloat16* residual,
const __nv_bfloat16* hidden_state,
const __nv_bfloat16* attn,
const __nv_bfloat16* bias,
const __nv_bfloat16* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale);
#endif
template __global__ void gptj_residual_add(__half* residual,
const __half* hidden_state,
const __half* attn,
const __half* bias,
const __half* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale);
template <typename T> template <typename T>
__global__ void moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim) __global__ void moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim)
{ {
...@@ -454,6 +548,16 @@ template void launch_moe_res_matmul(float* residual, ...@@ -454,6 +548,16 @@ template void launch_moe_res_matmul(float* residual,
int seq_len, int seq_len,
int hidden_dim, int hidden_dim,
cudaStream_t stream); cudaStream_t stream);
#ifdef BF16_AVAILABLE
template void launch_moe_res_matmul(__nv_bfloat16* residual,
__nv_bfloat16* coef,
__nv_bfloat16* mlp_out,
int seq_len,
int hidden_dim,
cudaStream_t stream);
#endif
template void launch_moe_res_matmul(__half* residual, template void launch_moe_res_matmul(__half* residual,
__half* coef, __half* coef,
__half* mlp_out, __half* mlp_out,
...@@ -461,11 +565,11 @@ template void launch_moe_res_matmul(__half* residual, ...@@ -461,11 +565,11 @@ template void launch_moe_res_matmul(__half* residual,
int hidden_dim, int hidden_dim,
cudaStream_t stream); cudaStream_t stream);
__global__ void pad_data_kernel(__half* padded_output, template <typename T>
__half* output, __global__ void pad_data_kernel(T* padded_output, T* output, int head_size, int padded_head_size)
int head_size,
int padded_head_size)
{ {
using T2 =
typename std::conditional<std::is_same<T, __half>::value, __half2, __nv_bfloat162>::type;
float4* padded_output_cast = reinterpret_cast<float4*>(padded_output); float4* padded_output_cast = reinterpret_cast<float4*>(padded_output);
float4* output_cast = reinterpret_cast<float4*>(output); float4* output_cast = reinterpret_cast<float4*>(output);
int bid = blockIdx.x * (blockDim.y) + threadIdx.y; int bid = blockIdx.x * (blockDim.y) + threadIdx.y;
...@@ -473,8 +577,8 @@ __global__ void pad_data_kernel(__half* padded_output, ...@@ -473,8 +577,8 @@ __global__ void pad_data_kernel(__half* padded_output,
padded_output_cast += (bid * padded_head_size); padded_output_cast += (bid * padded_head_size);
output_cast += (bid * head_size); output_cast += (bid * head_size);
float4 ZERO; float4 ZERO;
const __half2 zero_h = __float2half2_rn(0.f); const T2 zero_h = conversion::to<T2>(0.f);
__half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); T2* ZERO_h = reinterpret_cast<T2*>(&ZERO);
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h;
if (idx < head_size) if (idx < head_size)
...@@ -482,12 +586,14 @@ __global__ void pad_data_kernel(__half* padded_output, ...@@ -482,12 +586,14 @@ __global__ void pad_data_kernel(__half* padded_output,
else else
padded_output_cast[idx] = ZERO; padded_output_cast[idx] = ZERO;
} }
__global__ void pad_data_kernel(float* padded_output, __global__ void pad_data_kernel(float* padded_output,
float* output, float* output,
int head_size, int head_size,
int padded_head_size) int padded_head_size)
{ {
} }
template <typename T> template <typename T>
void pad_data(T* padded_output, void pad_data(T* padded_output,
T* output, T* output,
...@@ -507,6 +613,16 @@ template void pad_data(__half* padded_output, ...@@ -507,6 +613,16 @@ template void pad_data(__half* padded_output,
int head_size, int head_size,
int padded_head_size, int padded_head_size,
cudaStream_t stream); cudaStream_t stream);
#ifdef BF16_AVAILABLE
template void pad_data(__nv_bfloat16* padded_output,
__nv_bfloat16* output,
int bsz,
int head_size,
int padded_head_size,
cudaStream_t stream);
#endif
template void pad_data(float* padded_output, template void pad_data(float* padded_output,
float* output, float* output,
int bsz, int bsz,
...@@ -514,13 +630,28 @@ template void pad_data(float* padded_output, ...@@ -514,13 +630,28 @@ template void pad_data(float* padded_output,
int padded_head_size, int padded_head_size,
cudaStream_t stream); cudaStream_t stream);
__global__ void pad_head_seq_kernel(__half* padded_output, #ifdef BF16_AVAILABLE
__half* output, template __global__ void pad_data_kernel(__nv_bfloat16* padded_output,
__nv_bfloat16* output,
int head_size,
int padded_head_size);
#endif
template __global__ void pad_data_kernel(__half* padded_output,
__half* output,
int head_size,
int padded_head_size);
template <typename T>
__global__ void pad_head_seq_kernel(T* padded_output,
T* output,
int seq_len, int seq_len,
int padded_seq_len, int padded_seq_len,
int head_size, int head_size,
int padded_head_size) int padded_head_size)
{ {
using T2 =
typename std::conditional<std::is_same<T, __half>::value, __half2, __nv_bfloat162>::type;
float4* padded_output_cast = reinterpret_cast<float4*>(padded_output); float4* padded_output_cast = reinterpret_cast<float4*>(padded_output);
float4* output_cast = reinterpret_cast<float4*>(output); float4* output_cast = reinterpret_cast<float4*>(output);
int bsz = blockIdx.x; int bsz = blockIdx.x;
...@@ -529,8 +660,8 @@ __global__ void pad_head_seq_kernel(__half* padded_output, ...@@ -529,8 +660,8 @@ __global__ void pad_head_seq_kernel(__half* padded_output,
padded_output_cast += (bsz * padded_seq_len + bid) * padded_head_size; padded_output_cast += (bsz * padded_seq_len + bid) * padded_head_size;
output_cast += (bsz * seq_len + bid) * head_size; output_cast += (bsz * seq_len + bid) * head_size;
float4 ZERO; float4 ZERO;
const __half2 zero_h = __float2half2_rn(0.f); const T2 zero_h = conversion::to<T2>(0.f);
__half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO); T2* ZERO_h = reinterpret_cast<T2*>(&ZERO);
#pragma unroll #pragma unroll
for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h; for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h;
...@@ -539,6 +670,7 @@ __global__ void pad_head_seq_kernel(__half* padded_output, ...@@ -539,6 +670,7 @@ __global__ void pad_head_seq_kernel(__half* padded_output,
else else
padded_output_cast[idx] = ZERO; padded_output_cast[idx] = ZERO;
} }
__global__ void pad_head_seq_kernel(float* padded_output, __global__ void pad_head_seq_kernel(float* padded_output,
float* output, float* output,
int seq_len, int seq_len,
...@@ -547,6 +679,7 @@ __global__ void pad_head_seq_kernel(float* padded_output, ...@@ -547,6 +679,7 @@ __global__ void pad_head_seq_kernel(float* padded_output,
int padded_head_size) int padded_head_size)
{ {
} }
template <typename T> template <typename T>
void pad_head_seq(T* padded_output, void pad_head_seq(T* padded_output,
T* output, T* output,
...@@ -562,6 +695,7 @@ void pad_head_seq(T* padded_output, ...@@ -562,6 +695,7 @@ void pad_head_seq(T* padded_output,
pad_head_seq_kernel<<<grid_dim, block_dim, 0, stream>>>( pad_head_seq_kernel<<<grid_dim, block_dim, 0, stream>>>(
padded_output, output, seq_len, padded_seq_len, head_size / 8, padded_head_size / 8); padded_output, output, seq_len, padded_seq_len, head_size / 8, padded_head_size / 8);
} }
template void pad_head_seq(__half* padded_output, template void pad_head_seq(__half* padded_output,
__half* output, __half* output,
int bsz, int bsz,
...@@ -570,6 +704,18 @@ template void pad_head_seq(__half* padded_output, ...@@ -570,6 +704,18 @@ template void pad_head_seq(__half* padded_output,
int head_size, int head_size,
int padded_head_size, int padded_head_size,
cudaStream_t stream); cudaStream_t stream);
#ifdef BF16_AVAILABLE
template void pad_head_seq(__nv_bfloat16* padded_output,
__nv_bfloat16* output,
int bsz,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size,
cudaStream_t stream);
#endif
template void pad_head_seq(float* padded_output, template void pad_head_seq(float* padded_output,
float* output, float* output,
int bsz, int bsz,
...@@ -680,4 +826,12 @@ template void launch_fused_bias_geglu(__half*, ...@@ -680,4 +826,12 @@ template void launch_fused_bias_geglu(__half*,
int, int,
int, int,
cudaStream_t); cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_fused_bias_geglu(__nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
int,
int,
cudaStream_t);
#endif
template void launch_fused_bias_geglu(float*, const float*, const float*, int, int, cudaStream_t); template void launch_fused_bias_geglu(float*, const float*, const float*, int, int, cudaStream_t);
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include "conversion_utils.h" #include "conversion_utils.h"
#include "ds_kernel_utils.h" #include "ds_kernel_utils.h"
...@@ -196,6 +197,16 @@ template void launch_fused_ln(__half*, ...@@ -196,6 +197,16 @@ template void launch_fused_ln(__half*,
int, int,
int, int,
cudaStream_t); cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_fused_ln(__nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
float,
int,
int,
cudaStream_t);
#endif
template void template void
launch_fused_ln(float*, const float*, const float*, const float*, float, int, int, cudaStream_t); launch_fused_ln(float*, const float*, const float*, const float*, float, int, int, cudaStream_t);
...@@ -492,6 +503,19 @@ template void launch_fused_residual_ln(__half*, ...@@ -492,6 +503,19 @@ template void launch_fused_residual_ln(__half*,
int, int,
cudaStream_t); cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_fused_residual_ln(__nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
float,
int,
int,
cudaStream_t);
#endif
template void launch_fused_residual_ln(float*, template void launch_fused_residual_ln(float*,
const float*, const float*,
const float*, const float*,
...@@ -516,6 +540,20 @@ template void launch_fused_residual_ln_store_pre_ln_res(__half*, ...@@ -516,6 +540,20 @@ template void launch_fused_residual_ln_store_pre_ln_res(__half*,
int, int,
cudaStream_t); cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_fused_residual_ln_store_pre_ln_res(__nv_bfloat16*,
__nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
float,
int,
int,
cudaStream_t);
#endif
template void launch_fused_residual_ln_store_pre_ln_res(float*, template void launch_fused_residual_ln_store_pre_ln_res(float*,
float*, float*,
const float*, const float*,
......
#include <limits>
#include "custom_cuda_layers.h"
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define NORM_REG (MAX_REGISTERS)
namespace cg = cooperative_groups;
/*
Fused LayerNorm kernel (float path).

One CUDA block normalizes one row of `vals` of length `row_stride` and
writes `gamma * (x - mean) * rsqrt(var + epsilon) + beta` into `output`.
The row is cached in registers (`inp_reg`), so the launch must satisfy
row_stride / blockDim.x <= NORM_REG, and blockDim.x must be a multiple
of 32 for the warp-shuffle reductions below.

Change vs. original: removed the dead local `iterations`
(`row_stride / iteration_stride`), which was computed but never read.
*/
__global__ void fused_bias_residual_layer_norm(float* output,
                                               const float* vals,
                                               const float* gamma,
                                               const float* beta,
                                               float epsilon,
                                               int row_stride)
{
    int iteration_stride = blockDim.x;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;                       // warp index within the block
    int warp_num = iteration_stride >> 5;    // number of warps in the block

    float inp_reg[NORM_REG];  // per-thread register cache of this row's elements

    // Strided load of the row into registers while accumulating the sum
    // for the mean; `k` counts how many elements this thread owns.
    int k = 0;
    float sum = 0;
    int input_id = id;
    while (input_id < row_stride) {
        inp_reg[k] = vals[input_id + row * row_stride];
        sum += inp_reg[k++];
        input_id += iteration_stride;
    }

    // Two-level reduction: shuffle within each warp, stage per-warp
    // partials in shared memory, then reduce those in the first warp.
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    __shared__ float shr[MAX_WARP_NUM];
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);  // broadcast the total to every lane
    float mean = sum / (row_stride);

    // Center the cached values and accumulate the variance numerator.
    sum = 0.f;
    for (int f = 0; f < k; f++) {
        inp_reg[f] -= mean;
        sum += inp_reg[f] * inp_reg[f];
    }

    // Same two-level reduction for the variance.
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);
    sum /= (row_stride);
    sum += epsilon;
    sum = __frsqrt_rn(sum);  // 1 / sqrt(var + epsilon)

    // Normalize, apply affine transform, and write the row back.
    for (int f = 0; f < k; f++) {
        int out_id = f * iteration_stride + id;
        inp_reg[f] = inp_reg[f] * sum;
        inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
        output[out_id + row * row_stride] = inp_reg[f];
    }
}
/*
Fused LayerNorm kernel (__half path, vectorized as __half2).

One CUDA block normalizes one row. `row_stride` is in __half2 units
(half the element count), so the mean/variance divide by
`row_stride << 1`. Statistics are accumulated in float for accuracy.
Same launch preconditions as the float kernel: blockDim.x a multiple of
32, and row_stride / blockDim.x <= NORM_REG. The body is compiled only
when HALF_PRECISION_AVAILABLE is defined; otherwise the kernel is a no-op.

Change vs. original: removed the dead local `iterations`
(`row_stride / iteration_stride`), which was computed but never read.
*/
__global__ void fused_bias_residual_layer_norm(__half* output,
                                               const __half* vals,
                                               const __half* gamma,
                                               const __half* beta,
                                               float epsilon,
                                               int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE

    int iteration_stride = blockDim.x;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;                       // warp index within the block
    int warp_num = iteration_stride >> 5;    // number of warps in the block

    __half2 inp_reg[NORM_REG];  // per-thread register cache (2 halves per slot)

    const __half2* vals_cast = reinterpret_cast<const __half2*>(vals);
    __half2* out_cast = reinterpret_cast<__half2*>(output);

    // Strided load of the row into registers.
    int k = 0;
    int input_id = id;
    while (input_id < row_stride) {
        inp_reg[k++] = vals_cast[input_id + row * row_stride];
        input_id += iteration_stride;
    }

    // Accumulate the sum in float (both lanes of each __half2).
    float sum = 0;
    for (int f = k - 1; f >= 0; f--) {
        float2 inp_f = __half22float2(inp_reg[f]);
        sum += inp_f.x + inp_f.y;
    }

    // Two-level reduction: warp shuffles, then shared memory across warps.
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    __shared__ float shr[MAX_WARP_NUM];
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);  // broadcast the total to every lane
    float mean = sum / (row_stride << 1);  // element count is 2 * row_stride

    // Center in float, store back as __half2, accumulate variance.
    sum = 0.f;
    for (int f = 0; f < k; f++) {
        float2 inp_f = __half22float2(inp_reg[f]);
        inp_f.x -= mean;
        inp_f.y -= mean;
        inp_reg[f] = __float22half2_rn(inp_f);
        sum += inp_f.x * inp_f.x;
        sum += inp_f.y * inp_f.y;
    }

    // Same two-level reduction for the variance.
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);
    sum /= (row_stride << 1);
    sum += epsilon;
    sum = __frsqrt_rn(sum);  // 1 / sqrt(var + epsilon)
    __half2 variance_h = __float2half2_rn(sum);

    const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
    const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);

    // Normalize, apply affine transform, and write the row back.
    for (int f = 0; f < k; f++) {
        int out_id = f * iteration_stride + id;
        inp_reg[f] = inp_reg[f] * variance_h;
        inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
        out_cast[out_id + row * row_stride] = inp_reg[f];
    }
#endif
}
// Host-side launcher for the fused LayerNorm kernel above:
// out = LN(vals) * gamma + beta, one CUDA block per batch row, enqueued on
// `stream`. Only the float and __half specializations below are defined.
template <typename T>
void launch_layer_norm(T* out,
                       T* vals,
                       const T* gamma,
                       const T* beta,
                       float epsilon,
                       int batch_size,
                       int hidden_dim,
                       cudaStream_t stream);
// float specialization: one block per row, 1024 threads cooperating on the
// reduction; `hidden_dim` is passed through unchanged (scalar float loads).
template <>
void launch_layer_norm<float>(float* out,
                              float* vals,
                              const float* gamma,
                              const float* beta,
                              float epsilon,
                              int batch_size,
                              int hidden_dim,
                              cudaStream_t stream)
{
    constexpr int kThreads = 1024;

    const dim3 grid(batch_size);
    const dim3 block(kThreads);

    fused_bias_residual_layer_norm<<<grid, block, 0, stream>>>(
        out, vals, gamma, beta, epsilon, hidden_dim);
}
// __half specialization: identical launch shape, but the kernel operates on
// __half2 vectors, so the row stride is hidden_dim / 2.
template <>
void launch_layer_norm<__half>(__half* out,
                               __half* vals,
                               const __half* gamma,
                               const __half* beta,
                               float epsilon,
                               int batch_size,
                               int hidden_dim,
                               cudaStream_t stream)
{
    constexpr int kThreads = 1024;

    const dim3 grid(batch_size);
    const dim3 block(kThreads);

    fused_bias_residual_layer_norm<<<grid, block, 0, stream>>>(
        out, vals, gamma, beta, epsilon, hidden_dim / 2);
}
/*
Fused residual-add + LayerNorm kernel (float path).
One block handles one row of length `row_stride`. When `mlp_after_attn`
is true, `residual` and `bias` are added into the loaded values before
normalization; the normalized, gamma/beta-scaled result goes to `norm`.
NOTE(review): `res_add` and `preLN` are currently unused here — the
pre-LN residual store is commented out below.
Preconditions mirror the other kernels in this file: blockDim.x is a
multiple of 32 and row_stride / blockDim.x <= NORM_REG.
*/
__global__ void fused_residual_layer_norm(float* norm,
                                          float* res_add,
                                          float* vals,
                                          float* residual,
                                          const float* bias,
                                          const float* gamma,
                                          const float* beta,
                                          float epsilon,
                                          int row_stride,
                                          bool preLN,
                                          bool mlp_after_attn)
{
    int iteration_stride = blockDim.x;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;                     // warp index within the block
    int warp_num = iteration_stride >> 5;  // number of warps in the block

    float inp_reg[NORM_REG];  // per-thread register cache of this row

    // Strided load; optionally fuse residual + bias; accumulate row sum.
    int k = 0;
    int input_id = id;
    float sum = 0;
    while (input_id < row_stride) {
        inp_reg[k] = vals[input_id + row * row_stride];
        float res_f = (residual[input_id + row * row_stride]);
        float bias_f = (bias[input_id]);
        if (mlp_after_attn) inp_reg[k] += res_f + bias_f;
        // if (preLN) res_add[input_id + row * row_stride] = inp_reg[k];
        sum += inp_reg[k++];
        input_id += iteration_stride;
    }
    // Two-level reduction: warp shuffles, then shared memory across warps.
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    __shared__ float shr[MAX_WARP_NUM];
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);  // broadcast the total to every lane
    float mean = sum / (row_stride);
    // Center and accumulate the variance numerator.
    sum = 0.f;
    for (int f = 0; f < k; f++) {
        inp_reg[f] -= mean;
        sum += inp_reg[f] * inp_reg[f];
    }
    // Same two-level reduction for the variance.
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);
    sum /= (row_stride);
    sum += epsilon;
    sum = __frsqrt_rn(sum);  // 1 / sqrt(var + epsilon)
    // Normalize, apply affine transform, and write out.
    for (int f = 0; f < k; f++) {
        int out_id = f * iteration_stride + id;
        inp_reg[f] = inp_reg[f] * sum;
        inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
        norm[out_id + row * row_stride] = inp_reg[f];
    }
}
/*
Fused residual-add + LayerNorm kernel (__half path, vectorized as __half2).
`row_stride` is in __half2 units, so statistics divide by row_stride << 1;
all accumulation is done in float for accuracy. When `mlp_after_attn` is
true, `residual` and `bias` are added into the loaded values before
normalization. Compiled only under HALF_PRECISION_AVAILABLE; otherwise
the kernel body is empty.
NOTE(review): `res_add` and `preLN` are currently unused here — the
pre-LN residual store is commented out below.
*/
__global__ void fused_residual_layer_norm(__half* norm,
                                          __half* res_add,
                                          __half* vals,
                                          __half* residual,
                                          const __half* bias,
                                          const __half* gamma,
                                          const __half* beta,
                                          float epsilon,
                                          int row_stride,
                                          bool preLN,
                                          bool mlp_after_attn)
{
#ifdef HALF_PRECISION_AVAILABLE
    int iteration_stride = blockDim.x;

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);

    int row = blockIdx.x;
    int id = threadIdx.x;
    int gid = id >> 5;                     // warp index within the block
    int warp_num = iteration_stride >> 5;  // number of warps in the block

    __half2 inp_reg[NORM_REG];  // per-thread register cache (2 halves per slot)

    __half2* vals_cast = reinterpret_cast<__half2*>(vals);
    __half2* norm_cast = reinterpret_cast<__half2*>(norm);
    __half2* res_add_cast = reinterpret_cast<__half2*>(res_add);
    __half2* residual_cast = reinterpret_cast<__half2*>(residual);
    const __half2* bias_cast = reinterpret_cast<const __half2*>(bias);

    // Strided load; optionally fuse residual + bias in float; accumulate sum.
    int k = 0;
    int input_id = id;

    float sum = 0;
    while (input_id < row_stride) {
        inp_reg[k] = vals_cast[input_id + row * row_stride];
        float2 inp_f = __half22float2(inp_reg[k]);
        float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]);
        float2 bias_f = __half22float2(bias_cast[input_id]);
        if (mlp_after_attn) {
            inp_f.x += res_f.x + bias_f.x;
            inp_f.y += res_f.y + bias_f.y;
        }
        inp_reg[k] = __float22half2_rn(inp_f);
        // if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f);
        // //inp_reg[k];
        sum += inp_f.x + inp_f.y;
        input_id += iteration_stride;
        k++;
    }
    // Two-level reduction: warp shuffles, then shared memory across warps.
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    __shared__ float shr[MAX_WARP_NUM];
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);  // broadcast the total to every lane
    float mean = sum / (row_stride << 1);  // element count is 2 * row_stride
    // Center in float, store back as __half2, accumulate variance.
    sum = 0.f;
    for (int f = 0; f < k; f++) {
        float2 inp_f = __half22float2(inp_reg[f]);
        inp_f.x -= mean;
        inp_f.y -= mean;
        inp_reg[f] = __float22half2_rn(inp_f);
        sum += inp_f.x * inp_f.x;
        sum += inp_f.y * inp_f.y;
    }
    // Same two-level reduction for the variance.
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
    sum = g.shfl(sum, 0);
    sum /= (row_stride << 1);
    sum += epsilon;
    sum = __frsqrt_rn(sum);  // 1 / sqrt(var + epsilon)
    __half2 variance_h = __float2half2_rn(sum);
    const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
    const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
    // Normalize, apply affine transform, and write out.
    for (int f = 0; f < k; f++) {
        int out_id = f * iteration_stride + id;
        inp_reg[f] = inp_reg[f] * variance_h;
        inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
        norm_cast[out_id + row * row_stride] = inp_reg[f];
    }
#endif
}
// Host-side launcher for the fused residual + LayerNorm kernel above:
// norm = LN(vals [+ residual + bias when mlp_after_attn]) * gamma + beta,
// one CUDA block per batch row, enqueued on `stream`. Only the float and
// __half specializations below are defined.
template <typename T>
void launch_residual_layer_norm(T* norm,
                                T* res_add,
                                T* vals,
                                T* residual,
                                const T* bias,
                                const T* gamma,
                                const T* beta,
                                float epsilon,
                                int batch_size,
                                int hidden_dim,
                                bool preLN,
                                bool mlp_after_attn,
                                cudaStream_t stream);
// float specialization: one block per row, 1024 threads; `hidden_dim` is
// passed through unchanged (scalar float loads in the kernel).
template <>
void launch_residual_layer_norm<float>(float* norm,
                                       float* res_add,
                                       float* vals,
                                       float* residual,
                                       const float* bias,
                                       const float* gamma,
                                       const float* beta,
                                       float epsilon,
                                       int batch_size,
                                       int hidden_dim,
                                       bool preLN,
                                       bool mlp_after_attn,
                                       cudaStream_t stream)
{
    constexpr int kThreads = 1024;

    const dim3 grid(batch_size);
    const dim3 block(kThreads);

    fused_residual_layer_norm<<<grid, block, 0, stream>>>(norm,
                                                          res_add,
                                                          vals,
                                                          residual,
                                                          bias,
                                                          gamma,
                                                          beta,
                                                          epsilon,
                                                          hidden_dim,
                                                          preLN,
                                                          mlp_after_attn);
}
// __half specialization: identical launch shape, but the kernel operates on
// __half2 vectors, so the row stride is hidden_dim / 2.
template <>
void launch_residual_layer_norm<__half>(__half* norm,
                                        __half* res_add,
                                        __half* vals,
                                        __half* residual,
                                        const __half* bias,
                                        const __half* gamma,
                                        const __half* beta,
                                        float epsilon,
                                        int batch_size,
                                        int hidden_dim,
                                        bool preLN,
                                        bool mlp_after_attn,
                                        cudaStream_t stream)
{
    constexpr int kThreads = 1024;

    const dim3 grid(batch_size);
    const dim3 block(kThreads);

    fused_residual_layer_norm<<<grid, block, 0, stream>>>(norm,
                                                          res_add,
                                                          vals,
                                                          residual,
                                                          bias,
                                                          gamma,
                                                          beta,
                                                          epsilon,
                                                          hidden_dim / 2,
                                                          preLN,
                                                          mlp_after_attn);
}
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include <c10/cuda/CUDAStream.h> #include <c10/cuda/CUDAStream.h>
#include <torch/extension.h> #include <torch/extension.h>
...@@ -95,7 +96,7 @@ at::Tensor ds_softmax(at::Tensor& attn_scores, ...@@ -95,7 +96,7 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
head_offset, head_offset,
mask_stride, mask_stride,
mp_size, mp_size,
Context::Instance().GetCurrentStream(async_op)); InferenceContext::Instance().GetCurrentStream(async_op));
return attn_scores_c; return attn_scores_c;
} }
...@@ -109,18 +110,20 @@ void allocate_workspace(unsigned hidden_dim, ...@@ -109,18 +110,20 @@ void allocate_workspace(unsigned hidden_dim,
unsigned mp_size = 1, unsigned mp_size = 1,
bool external_cache = false, bool external_cache = false,
unsigned rank = 0, unsigned rank = 0,
unsigned max_out_tokens = 1024) unsigned max_out_tokens = 1024,
unsigned min_out_tokens = 1)
{ {
Context::Instance().GenWorkSpace(num_layers, InferenceContext::Instance().GenWorkSpace(num_layers,
num_heads, num_heads,
batch_size, batch_size,
prompt_length, prompt_length,
hidden_dim, hidden_dim,
mp_size, mp_size,
external_cache, external_cache,
sizeof(T), sizeof(T),
rank, rank,
max_out_tokens); max_out_tokens,
min_out_tokens);
} }
template <typename T> template <typename T>
...@@ -131,15 +134,15 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) ...@@ -131,15 +134,15 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
.layout(at::kStrided) .layout(at::kStrided)
.device(at::kCUDA) .device(at::kCUDA)
.requires_grad(false); .requires_grad(false);
T* workspace = (T*)Context::Instance().GetWorkSpace(); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
float alpha = 1; float alpha = 1;
float gemm_beta = 0.0; float gemm_beta = 0.0;
/* /*
// Reallocate memory if we received a new prompt // Reallocate memory if we received a new prompt
if (!workspace || input.size(1) != 1) { if (!workspace || input.size(1) != 1) {
allocate_workspace<T>(W.size(1), Context::Instance().GetMaxTokenLenght(), Q.size(0), 1, allocate_workspace<T>(W.size(1), InferenceContext::Instance().GetMaxTokenLenght(),
head_size); workspace = (T*)Context::Instance().GetWorkSpace(); Q.size(0), 1, head_size); workspace = (T*)InferenceContext::Instance().GetWorkSpace();
} }
*/ */
...@@ -147,7 +150,7 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) ...@@ -147,7 +150,7 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
unsigned m = W.size(1); unsigned m = W.size(1);
unsigned n = Q.size(1) * Q.size(2); unsigned n = Q.size(1) * Q.size(2);
unsigned k = Q.size(0); unsigned k = Q.size(0);
cublas_gemm_ex(Context::Instance().GetCublasHandle(), cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
m, m,
...@@ -194,8 +197,9 @@ void attention_unfused(at::Tensor& prev_key_cont, ...@@ -194,8 +197,9 @@ void attention_unfused(at::Tensor& prev_key_cont,
auto mask_stride = get_attn_mask_stride(attn_mask); auto mask_stride = get_attn_mask_stride(attn_mask);
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), InferenceContext::Instance().GetCurrentStream());
cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
soft_len, soft_len,
seq_len, seq_len,
k, k,
...@@ -230,9 +234,9 @@ void attention_unfused(at::Tensor& prev_key_cont, ...@@ -230,9 +234,9 @@ void attention_unfused(at::Tensor& prev_key_cont,
0, 0,
mask_stride, mask_stride,
1, 1,
Context::Instance().GetCurrentStream(false)); InferenceContext::Instance().GetCurrentStream(false));
alpha = 1.0; alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
k, k,
seq_len, seq_len,
soft_len, soft_len,
...@@ -363,10 +367,11 @@ void attention_unfused(T* prev_key_cont, ...@@ -363,10 +367,11 @@ void attention_unfused(T* prev_key_cont,
float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0; float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0;
float alpha = norm_factor * norm_factor / layer_scale; float alpha = norm_factor * norm_factor / layer_scale;
float gemm_beta = 0.0; float gemm_beta = 0.0;
T* workspace = (T*)Context::Instance().GetAttentionUnfusedWorkspace(); T* workspace = (T*)InferenceContext::Instance().GetAttentionUnfusedWorkspace();
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), InferenceContext::Instance().GetCurrentStream());
cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
soft_len, soft_len,
seq_len, seq_len,
k, k,
...@@ -377,7 +382,7 @@ void attention_unfused(T* prev_key_cont, ...@@ -377,7 +382,7 @@ void attention_unfused(T* prev_key_cont,
workspace, workspace,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
Context::Instance().GetMaxTokenLenght() * k, InferenceContext::Instance().GetMaxTokenLenght() * k,
seq_len * k, seq_len * k,
seq_len * soft_len, seq_len * soft_len,
bsz * heads, bsz * heads,
...@@ -399,7 +404,7 @@ void attention_unfused(T* prev_key_cont, ...@@ -399,7 +404,7 @@ void attention_unfused(T* prev_key_cont,
soft_len, soft_len,
heads); heads);
alpha = 1.0; alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
k, k,
seq_len, seq_len,
soft_len, soft_len,
...@@ -410,7 +415,7 @@ void attention_unfused(T* prev_key_cont, ...@@ -410,7 +415,7 @@ void attention_unfused(T* prev_key_cont,
(T*)output, (T*)output,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
Context::Instance().GetMaxTokenLenght() * k, InferenceContext::Instance().GetMaxTokenLenght() * k,
seq_len * soft_len, seq_len * soft_len,
seq_len * k, seq_len * k,
bsz * heads, bsz * heads,
...@@ -421,7 +426,7 @@ void attention_unfused(T* prev_key_cont, ...@@ -421,7 +426,7 @@ void attention_unfused(T* prev_key_cont,
#endif #endif
} }
void reset_cache() { Context::Instance().reset_tokens(); } void reset_cache() { InferenceContext::Instance().reset_tokens(); }
template <typename T> template <typename T>
std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value, std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
...@@ -445,8 +450,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value, ...@@ -445,8 +450,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
bool is_prompt = (seq_len > 1); bool is_prompt = (seq_len > 1);
if (is_prompt) Context::Instance().reset_tokens(seq_len); if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len);
unsigned soft_len = Context::Instance().current_tokens(); unsigned soft_len = InferenceContext::Instance().current_tokens();
int k = hidden_dim / heads; int k = hidden_dim / heads;
auto options = at::TensorOptions() auto options = at::TensorOptions()
...@@ -455,16 +460,17 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value, ...@@ -455,16 +460,17 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
.device(at::kCUDA) .device(at::kCUDA)
.requires_grad(false); .requires_grad(false);
T* workspace = (T*)Context::Instance().GetWorkSpace(); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
size_t buf_size = bsz * seq_len * hidden_dim; size_t buf_size = bsz * seq_len * hidden_dim;
auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options); auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options);
auto query_cont = workspace + 8 * buf_size; auto query_cont = workspace + 5 * buf_size;
size_t offset = 16 * (hidden_dim * bsz * Context::Instance().GetMaxTokenLenght()) + size_t offset =
layer_id * 2 * bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim; 10 * (hidden_dim * bsz * InferenceContext::Instance().GetMaxTokenLenght()) +
layer_id * 2 * bsz * InferenceContext::Instance().GetMaxTokenLenght() * hidden_dim;
unsigned all_tokens = soft_len; unsigned all_tokens = soft_len;
auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1); auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1);
size_t value_offset = bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim; size_t value_offset = bsz * InferenceContext::Instance().GetMaxTokenLenght() * hidden_dim;
T* temp_buf = (T*)output.data_ptr() + at::numel(output); T* temp_buf = (T*)output.data_ptr() + at::numel(output);
launch_bias_add_transform_0213<T>((T*)query_cont, launch_bias_add_transform_0213<T>((T*)query_cont,
...@@ -481,9 +487,9 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value, ...@@ -481,9 +487,9 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
rotary_dim, rotary_dim,
rotate_half, rotate_half,
rotate_every_two, rotate_every_two,
Context::Instance().GetCurrentStream(), InferenceContext::Instance().GetCurrentStream(),
3, 3,
Context::Instance().GetMaxTokenLenght()); InferenceContext::Instance().GetMaxTokenLenght());
if (rotary_dim > 0 && rotate_half) if (rotary_dim > 0 && rotate_half)
launch_apply_rotary_pos_emb(query_cont, launch_apply_rotary_pos_emb(query_cont,
kv_cache, kv_cache,
...@@ -495,8 +501,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value, ...@@ -495,8 +501,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
bsz, bsz,
rotate_half, rotate_half,
rotate_every_two, rotate_every_two,
Context::Instance().GetCurrentStream(), InferenceContext::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght()); InferenceContext::Instance().GetMaxTokenLenght());
attention_unfused<T>(workspace + offset, attention_unfused<T>(workspace + offset,
(T*)query_cont, (T*)query_cont,
...@@ -521,13 +527,27 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value, ...@@ -521,13 +527,27 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
heads, heads,
seq_len, seq_len,
output.size(2), output.size(2),
Context::Instance().GetCurrentStream(false), InferenceContext::Instance().GetCurrentStream(false),
1); 1);
if (layer_id == num_layers - 1) Context::Instance().advance_tokens(); if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens();
auto prev_key = torch::from_blob(workspace + offset, {bsz, heads, all_tokens, k}, options); auto prev_key = torch::from_blob(workspace + offset,
{bsz, heads, all_tokens, k},
{hidden_dim * InferenceContext::Instance().GetMaxTokenLenght(),
k * InferenceContext::Instance().GetMaxTokenLenght(),
k,
1},
options);
auto prev_value = auto prev_value =
torch::from_blob(workspace + offset + value_offset, {bsz, heads, all_tokens, k}, options); torch::from_blob(workspace + offset + value_offset,
{bsz, heads, all_tokens, k},
{hidden_dim * InferenceContext::Instance().GetMaxTokenLenght(),
k * InferenceContext::Instance().GetMaxTokenLenght(),
k,
1},
options);
return {output, prev_key, prev_value}; return {output, prev_key, prev_value};
} }
...@@ -543,7 +563,7 @@ at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias) ...@@ -543,7 +563,7 @@ at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
(T*)bias.data_ptr(), (T*)bias.data_ptr(),
intermediate_size, intermediate_size,
bsz, bsz,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
return input_cont; return input_cont;
} }
...@@ -569,14 +589,14 @@ at::Tensor ds_bias_geglu(at::Tensor& activation, at::Tensor& bias) ...@@ -569,14 +589,14 @@ at::Tensor ds_bias_geglu(at::Tensor& activation, at::Tensor& bias)
(const float*)bias.data_ptr(), (const float*)bias.data_ptr(),
rows, rows,
channels, channels,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
} else { } else {
launch_fused_bias_geglu((__half*)output.data_ptr(), launch_fused_bias_geglu((__half*)output.data_ptr(),
(const __half*)activation.data_ptr(), (const __half*)activation.data_ptr(),
(const __half*)bias.data_ptr(), (const __half*)bias.data_ptr(),
rows, rows,
channels, channels,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
} }
return output; return output;
...@@ -594,7 +614,7 @@ at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias) ...@@ -594,7 +614,7 @@ at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias)
(T*)bias.data_ptr(), (T*)bias.data_ptr(),
intermediate_size, intermediate_size,
bsz, bsz,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
return input_cont; return input_cont;
} }
...@@ -610,7 +630,7 @@ at::Tensor ds_bias_add(at::Tensor& input, at::Tensor& bias) ...@@ -610,7 +630,7 @@ at::Tensor ds_bias_add(at::Tensor& input, at::Tensor& bias)
(T*)bias.data_ptr(), (T*)bias.data_ptr(),
hidden_size, hidden_size,
bsz, bsz,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
return input_cont; return input_cont;
} }
...@@ -627,7 +647,7 @@ at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& ...@@ -627,7 +647,7 @@ at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor&
// bsz, // bsz,
// input_cont.size(2), // input_cont.size(2),
// (bias.size(0) > 1), // (bias.size(0) > 1),
// Context::Instance().GetCurrentStream()); // InferenceContext::Instance().GetCurrentStream());
return input_cont; return input_cont;
} }
...@@ -645,7 +665,7 @@ at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta, ...@@ -645,7 +665,7 @@ at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta,
epsilon, epsilon,
rows, rows,
elems_per_row, elems_per_row,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
} else { } else {
launch_fused_ln((float*)output.data_ptr(), launch_fused_ln((float*)output.data_ptr(),
(const float*)input.data_ptr(), (const float*)input.data_ptr(),
...@@ -654,7 +674,7 @@ at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta, ...@@ -654,7 +674,7 @@ at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta,
epsilon, epsilon,
rows, rows,
elems_per_row, elems_per_row,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
} }
return output; return output;
...@@ -675,7 +695,7 @@ void ds_layer_norm_internal(T* workspace, ...@@ -675,7 +695,7 @@ void ds_layer_norm_internal(T* workspace,
epsilon, epsilon,
bsz, bsz,
input.size(2), input.size(2),
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
} }
/* Currently only used in unit testing */ /* Currently only used in unit testing */
...@@ -700,7 +720,7 @@ at::Tensor ds_layer_norm_residual(at::Tensor& input, ...@@ -700,7 +720,7 @@ at::Tensor ds_layer_norm_residual(at::Tensor& input,
epsilon, epsilon,
rows, rows,
elems_per_row, elems_per_row,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
} else { } else {
launch_fused_residual_ln((float*)output.data_ptr(), launch_fused_residual_ln((float*)output.data_ptr(),
(const float*)input.data_ptr(), (const float*)input.data_ptr(),
...@@ -711,7 +731,7 @@ at::Tensor ds_layer_norm_residual(at::Tensor& input, ...@@ -711,7 +731,7 @@ at::Tensor ds_layer_norm_residual(at::Tensor& input,
epsilon, epsilon,
rows, rows,
elems_per_row, elems_per_row,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
} }
return output; return output;
...@@ -741,7 +761,7 @@ std::vector<at::Tensor> ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu ...@@ -741,7 +761,7 @@ std::vector<at::Tensor> ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu
epsilon, epsilon,
rows, rows,
elems_per_row, elems_per_row,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
} else { } else {
launch_fused_residual_ln_store_pre_ln_res((float*)norm_output.data_ptr(), launch_fused_residual_ln_store_pre_ln_res((float*)norm_output.data_ptr(),
(float*)res_output.data_ptr(), (float*)res_output.data_ptr(),
...@@ -753,7 +773,7 @@ std::vector<at::Tensor> ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu ...@@ -753,7 +773,7 @@ std::vector<at::Tensor> ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu
epsilon, epsilon,
rows, rows,
elems_per_row, elems_per_row,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
} }
return {norm_output, res_output}; return {norm_output, res_output};
...@@ -768,7 +788,7 @@ void quantized_gemm(void* output, ...@@ -768,7 +788,7 @@ void quantized_gemm(void* output,
int bsz, int bsz,
int hidden_size) int hidden_size)
{ {
// T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz; // T* weight16 = (T*)InferenceContext::Instance().GetWorkSpace() + 12 * hidden_size * bsz;
auto options = at::TensorOptions() auto options = at::TensorOptions()
.dtype(at::kHalf) .dtype(at::kHalf)
...@@ -783,11 +803,11 @@ void quantized_gemm(void* output, ...@@ -783,11 +803,11 @@ void quantized_gemm(void* output,
weight.size(0), weight.size(0),
weight.size(1), weight.size(1),
groups, groups,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
float alpha = (T)1.0; float alpha = (T)1.0;
float gemm_beta = (T)0.0; float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(), cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
weight.size(0), weight.size(0),
...@@ -815,10 +835,11 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, ...@@ -815,10 +835,11 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
at::Tensor& beta, at::Tensor& beta,
const float epsilon, const float epsilon,
bool add_bias, bool add_bias,
bool q_int8) bool q_int8,
bool transposed_mode)
{ {
int bsz = input.size(0) * input.size(1); int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace(); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
workspace += (3 * bsz * input.size(2)); workspace += (3 * bsz * input.size(2));
ds_layer_norm_internal<T>(workspace, input, gamma, beta, epsilon); ds_layer_norm_internal<T>(workspace, input, gamma, beta, epsilon);
...@@ -829,12 +850,12 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, ...@@ -829,12 +850,12 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
float alpha = (T)1.0; float alpha = (T)1.0;
float gemm_beta = (T)0.0; float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(), cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_N, (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N, CUBLAS_OP_N,
weight.size(1), weight.size(transposed_mode ? 0 : 1),
bsz, bsz,
input.size(2), input.size(2),
&alpha, &alpha,
...@@ -851,9 +872,9 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, ...@@ -851,9 +872,9 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
if (add_bias) if (add_bias)
launch_bias_add((T*)output.data_ptr(), launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(), (T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1), (transposed_mode || q_int8) ? weight.size(0) : weight.size(1),
bsz, bsz,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
return torch::from_blob(workspace, input.sizes(), input.options()); return torch::from_blob(workspace, input.sizes(), input.options());
} }
...@@ -870,11 +891,12 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input, ...@@ -870,11 +891,12 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
bool external_cache, bool external_cache,
unsigned mp_size, unsigned mp_size,
unsigned rank, unsigned rank,
bool q_int8) bool q_int8,
bool transposed_mode)
{ {
int bsz = input.size(0) * input.size(1); int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace(); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
int out_size = q_int8 ? weight.size(0) : weight.size(1); int out_size = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1);
auto options = at::TensorOptions() auto options = at::TensorOptions()
.dtype(input.options().dtype()) .dtype(input.options().dtype())
...@@ -883,8 +905,17 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input, ...@@ -883,8 +905,17 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
.requires_grad(false); .requires_grad(false);
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
auto inp_norm = qkv_unfused_cublas<T>( auto inp_norm = qkv_unfused_cublas<T>(output,
output, input, weight, q_scale, bias, gamma, beta, epsilon, add_bias, q_int8); input,
weight,
q_scale,
bias,
gamma,
beta,
epsilon,
add_bias,
q_int8,
transposed_mode);
return {output, inp_norm}; return {output, inp_norm};
} }
...@@ -912,11 +943,11 @@ void quantized_gemm(at::Tensor& output, ...@@ -912,11 +943,11 @@ void quantized_gemm(at::Tensor& output,
weight.size(1), weight.size(1),
groups, groups,
merge_count, merge_count,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
float alpha = (T)1.0; float alpha = (T)1.0;
float gemm_beta = (T)0.0; float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(), cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
weight.size(0), weight.size(0),
...@@ -963,7 +994,7 @@ at::Tensor ds_qkv_gemm_int8(at::Tensor& input, ...@@ -963,7 +994,7 @@ at::Tensor ds_qkv_gemm_int8(at::Tensor& input,
(T*)bias.data_ptr(), (T*)bias.data_ptr(),
weight.size(1), weight.size(1),
bsz, bsz,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
return output; return output;
} }
...@@ -974,7 +1005,8 @@ at::Tensor ds_linear_layer(at::Tensor& input, ...@@ -974,7 +1005,8 @@ at::Tensor ds_linear_layer(at::Tensor& input,
at::Tensor& bias, at::Tensor& bias,
bool add_bias, bool add_bias,
bool do_flash_attn, bool do_flash_attn,
int num_heads) int num_heads,
bool transposed_mode)
{ {
auto input_cont = input.contiguous(); auto input_cont = input.contiguous();
auto options = at::TensorOptions() auto options = at::TensorOptions()
...@@ -985,17 +1017,18 @@ at::Tensor ds_linear_layer(at::Tensor& input, ...@@ -985,17 +1017,18 @@ at::Tensor ds_linear_layer(at::Tensor& input,
int head_size = input_cont.size(2) / num_heads; int head_size = input_cont.size(2) / num_heads;
int bsz = input.size(0) * input.size(1); int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace(); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);
float alpha = (T)1.0; float alpha = (T)1.0;
float gemm_beta = (T)0.0; float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(), cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_N, (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N, CUBLAS_OP_N,
weight.size(1), weight.size(transposed_mode ? 0 : 1),
bsz, bsz,
input_cont.size(2), input_cont.size(2),
&alpha, &alpha,
...@@ -1011,9 +1044,9 @@ at::Tensor ds_linear_layer(at::Tensor& input, ...@@ -1011,9 +1044,9 @@ at::Tensor ds_linear_layer(at::Tensor& input,
if (add_bias) if (add_bias)
launch_bias_add((T*)output.data_ptr(), launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(), (T*)bias.data_ptr(),
weight.size(1), weight.size(transposed_mode ? 0 : 1),
bsz, bsz,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0); bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0);
if (do_flash_attn) { if (do_flash_attn) {
if (add_padding) { if (add_padding) {
...@@ -1026,7 +1059,7 @@ at::Tensor ds_linear_layer(at::Tensor& input, ...@@ -1026,7 +1059,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
3 * bsz * num_heads, 3 * bsz * num_heads,
head_size, head_size,
padded_head_size, padded_head_size,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
launch_bias_add_transform_0213<T>( launch_bias_add_transform_0213<T>(
final_output, final_output,
...@@ -1043,7 +1076,7 @@ at::Tensor ds_linear_layer(at::Tensor& input, ...@@ -1043,7 +1076,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
-1, -1,
false, false,
false, false,
Context::Instance().GetCurrentStream(), InferenceContext::Instance().GetCurrentStream(),
3, 3,
input.size(1)); input.size(1));
return at::from_blob(final_output, return at::from_blob(final_output,
...@@ -1068,7 +1101,7 @@ at::Tensor ds_linear_layer(at::Tensor& input, ...@@ -1068,7 +1101,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
-1, -1,
false, false,
false, false,
Context::Instance().GetCurrentStream(), InferenceContext::Instance().GetCurrentStream(),
3, 3,
input.size(1)); input.size(1));
return at::from_blob( return at::from_blob(
...@@ -1086,7 +1119,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens ...@@ -1086,7 +1119,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
{ {
int head_size = query.size(3); int head_size = query.size(3);
int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128); int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128);
T* workspace = (T*)Context::Instance().GetWorkSpace(); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2); T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2);
T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128; T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128;
pad_head_seq(workspace, pad_head_seq(workspace,
...@@ -1096,7 +1129,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens ...@@ -1096,7 +1129,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
query.size(2), query.size(2),
head_size, head_size,
padded_head_size, padded_head_size,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
pad_head_seq(key_pad_ptr, pad_head_seq(key_pad_ptr,
(T*)key.data_ptr(), (T*)key.data_ptr(),
query.size(0) * query.size(1), query.size(0) * query.size(1),
...@@ -1104,7 +1137,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens ...@@ -1104,7 +1137,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
128, 128,
head_size, head_size,
padded_head_size, padded_head_size,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
pad_head_seq(value_pad_ptr, pad_head_seq(value_pad_ptr,
(T*)value.data_ptr(), (T*)value.data_ptr(),
query.size(0) * query.size(1), query.size(0) * query.size(1),
...@@ -1112,7 +1145,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens ...@@ -1112,7 +1145,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
128, 128,
head_size, head_size,
padded_head_size, padded_head_size,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
return { return {
at::from_blob(workspace, at::from_blob(workspace,
{query.size(0), query.size(1), query.size(2), padded_head_size}, {query.size(0), query.size(1), query.size(2), padded_head_size},
...@@ -1134,7 +1167,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query, ...@@ -1134,7 +1167,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
int key_value_length = add_padding ? 128 : key.size(1); int key_value_length = add_padding ? 128 : key.size(1);
int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 64 : 128)) int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 64 : 128))
: head_size; : head_size;
T* workspace = (T*)Context::Instance().GetWorkSpace(); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1); T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1);
T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length; T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length;
launch_pad_add_transform_0213(workspace, launch_pad_add_transform_0213(workspace,
...@@ -1145,7 +1178,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query, ...@@ -1145,7 +1178,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
query.size(1), query.size(1),
heads, heads,
padded_head_size, padded_head_size,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
launch_pad_add_transform_0213(key_pad_ptr, launch_pad_add_transform_0213(key_pad_ptr,
(T*)key.data_ptr(), (T*)key.data_ptr(),
key.size(0), key.size(0),
...@@ -1154,7 +1187,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query, ...@@ -1154,7 +1187,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
key_value_length, key_value_length,
heads, heads,
padded_head_size, padded_head_size,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
launch_pad_add_transform_0213(value_pad_ptr, launch_pad_add_transform_0213(value_pad_ptr,
(T*)value.data_ptr(), (T*)value.data_ptr(),
value.size(0), value.size(0),
...@@ -1163,7 +1196,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query, ...@@ -1163,7 +1196,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
key_value_length, key_value_length,
heads, heads,
padded_head_size, padded_head_size,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
return { return {
at::from_blob( at::from_blob(
workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()), workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()),
...@@ -1196,7 +1229,7 @@ at::Tensor ds_linear_layer_int8(at::Tensor& input, ...@@ -1196,7 +1229,7 @@ at::Tensor ds_linear_layer_int8(at::Tensor& input,
(T*)bias.data_ptr(), (T*)bias.data_ptr(),
weight.size(1), weight.size(1),
bsz, bsz,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
return output; return output;
} }
...@@ -1205,7 +1238,8 @@ at::Tensor ds_vector_matmul(at::Tensor& input, ...@@ -1205,7 +1238,8 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
at::Tensor& weight, at::Tensor& weight,
bool async_op, bool async_op,
at::Tensor& q_scale, at::Tensor& q_scale,
bool q_int8) bool q_int8,
bool transposed_mode)
{ {
auto options = at::TensorOptions() auto options = at::TensorOptions()
.dtype(input.options().dtype()) .dtype(input.options().dtype())
...@@ -1215,7 +1249,7 @@ at::Tensor ds_vector_matmul(at::Tensor& input, ...@@ -1215,7 +1249,7 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
int out_size = q_int8 ? weight.size(0) : weight.size(1); int out_size = q_int8 ? weight.size(0) : weight.size(1);
int bsz = input.size(0) * input.size(1); int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace(); T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
if (q_int8) { if (q_int8) {
quantized_gemm<T>(output.data_ptr(), quantized_gemm<T>(output.data_ptr(),
...@@ -1228,12 +1262,12 @@ at::Tensor ds_vector_matmul(at::Tensor& input, ...@@ -1228,12 +1262,12 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
} else { } else {
float alpha = (T)1.0; float alpha = (T)1.0;
float gemm_beta = (T)0.0; float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream(async_op)); InferenceContext::Instance().GetCurrentStream(async_op));
cublas_gemm_ex(Context::Instance().GetCublasHandle(), cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_N, (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N, CUBLAS_OP_N,
weight.size(1), weight.size(transposed_mode ? 0 : 1),
bsz, bsz,
input.size(2), input.size(2),
&alpha, &alpha,
...@@ -1286,11 +1320,12 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, ...@@ -1286,11 +1320,12 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
at::Tensor& q_scale, at::Tensor& q_scale,
at::Tensor& q_scale1, at::Tensor& q_scale1,
bool q_int8, bool q_int8,
ActivationFuncType act_func_type) ActivationFuncType act_func_type,
bool transposed_mode)
{ {
int bsz = input.size(0) * input.size(1); int bsz = input.size(0) * input.size(1);
T* inp_norm = T* inp_norm = (T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input) +
(T*)Context::Instance().GetWorkSpace() + torch::numel(input) + torch::numel(output); torch::numel(output);
T* intermediate = inp_norm + torch::numel(input); T* intermediate = inp_norm + torch::numel(input);
if (mlp_after_attn) { if (mlp_after_attn) {
...@@ -1303,7 +1338,7 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, ...@@ -1303,7 +1338,7 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
epsilon, epsilon,
bsz, bsz,
input.size(2), input.size(2),
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
} else { } else {
ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon); ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon);
} }
...@@ -1313,12 +1348,12 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, ...@@ -1313,12 +1348,12 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
} else { } else {
float alpha = (T)1.0; float alpha = (T)1.0;
float gemm_beta = (T)0.0; float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(), cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, weight.size(transposed_mode ? 0 : 1),
weight.size(1),
bsz, bsz,
input.size(2), input.size(2),
&alpha, &alpha,
...@@ -1335,15 +1370,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, ...@@ -1335,15 +1370,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
if (act_func_type == ActivationFuncType::GELU) { if (act_func_type == ActivationFuncType::GELU) {
launch_bias_gelu(intermediate, launch_bias_gelu(intermediate,
(T*)bias.data_ptr(), (T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1), (transposed_mode || q_int8) ? weight.size(0) : weight.size(1),
bsz, bsz,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
} else if (act_func_type == ActivationFuncType::ReLU) { } else if (act_func_type == ActivationFuncType::ReLU) {
launch_bias_relu(intermediate, launch_bias_relu(intermediate,
(T*)bias.data_ptr(), (T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1), (transposed_mode || q_int8) ? weight.size(0) : weight.size(1),
bsz, bsz,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
} }
if (q_int8) { if (q_int8) {
...@@ -1357,14 +1392,14 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, ...@@ -1357,14 +1392,14 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
} else { } else {
float alpha = (T)1.0; float alpha = (T)1.0;
float gemm_beta = (T)0.0; float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(), cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_N, (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N, CUBLAS_OP_N,
weight1.size(1), weight1.size(transposed_mode ? 0 : 1),
bsz, bsz,
weight1.size(0), weight1.size(transposed_mode ? 1 : 0),
&alpha, &alpha,
&gemm_beta, &gemm_beta,
(T*)weight1.data_ptr(), (T*)weight1.data_ptr(),
...@@ -1395,7 +1430,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input, ...@@ -1395,7 +1430,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
at::Tensor& q_scale, at::Tensor& q_scale,
at::Tensor& q_scale1, at::Tensor& q_scale1,
bool q_int8, bool q_int8,
int activation_type) int activation_type,
bool transposed_mode)
{ {
auto options = at::TensorOptions() auto options = at::TensorOptions()
.dtype(input.options().dtype()) .dtype(input.options().dtype())
...@@ -1403,10 +1439,11 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input, ...@@ -1403,10 +1439,11 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
.device(at::kCUDA) .device(at::kCUDA)
.requires_grad(false); .requires_grad(false);
int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1); int out_size = (q_int8 || transposed_mode) ? weight_out.size(0) : weight_out.size(1);
auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input), auto output =
{input.size(0), input.size(1), out_size}, at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input),
options); {input.size(0), input.size(1), out_size},
options);
int bsz = input.size(0) * input.size(1); int bsz = input.size(0) * input.size(1);
auto act_func_type = static_cast<ActivationFuncType>(activation_type); auto act_func_type = static_cast<ActivationFuncType>(activation_type);
...@@ -1425,7 +1462,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input, ...@@ -1425,7 +1462,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
q_scale, q_scale,
q_scale1, q_scale1,
q_int8, q_int8,
act_func_type); act_func_type,
transposed_mode);
return {output, res_add}; return {output, res_add};
} }
...@@ -1461,7 +1499,7 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input, ...@@ -1461,7 +1499,7 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
(T*)bias.data_ptr(), (T*)bias.data_ptr(),
weight.size(1), weight.size(1),
bsz, bsz,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
return {output, residual_add}; return {output, residual_add};
} }
...@@ -1476,7 +1514,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, ...@@ -1476,7 +1514,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
const float epsilon, const float epsilon,
bool preLayerNorm, bool preLayerNorm,
bool q_int8, bool q_int8,
bool async_op) bool async_op,
bool transposed_mode)
{ {
auto options = at::TensorOptions() auto options = at::TensorOptions()
.dtype(input.options().dtype()) .dtype(input.options().dtype())
...@@ -1484,9 +1523,10 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, ...@@ -1484,9 +1523,10 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
.device(at::kCUDA) .device(at::kCUDA)
.requires_grad(false); .requires_grad(false);
int intm_dim = q_int8 ? weight.size(0) : weight.size(1); int intm_dim = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1);
// auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input), // auto output = at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() +
// torch::numel(input),
// {input.size(0), input.size(1), out_size}, // {input.size(0), input.size(1), out_size},
// options); // options);
// T* intermediate = (T*)input.data_ptr() + torch::numel(input); // T* intermediate = (T*)input.data_ptr() + torch::numel(input);
...@@ -1505,10 +1545,10 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, ...@@ -1505,10 +1545,10 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
bsz, bsz,
input.size(2)); input.size(2));
} else { } else {
cublasSetStream(Context::Instance().GetCublasHandle(), cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(), cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_N, (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N, CUBLAS_OP_N,
intm_dim, intm_dim,
bsz, bsz,
...@@ -1528,9 +1568,9 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, ...@@ -1528,9 +1568,9 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
(T*)bias.data_ptr(), (T*)bias.data_ptr(),
intm_dim, intm_dim,
bsz, bsz,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1); int out_size = (transposed_mode || q_int8) ? weight_out.size(0) : weight_out.size(1);
auto output = at::empty({input.size(0), input.size(1), out_size}, options); auto output = at::empty({input.size(0), input.size(1), out_size}, options);
if (q_int8) { if (q_int8) {
quantized_gemm<T>(output.data_ptr(), quantized_gemm<T>(output.data_ptr(),
...@@ -1541,8 +1581,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, ...@@ -1541,8 +1581,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
bsz, bsz,
input.size(2)); input.size(2));
} else { } else {
cublas_gemm_ex(Context::Instance().GetCublasHandle(), cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_N, (transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N, CUBLAS_OP_N,
out_size, out_size,
bsz, bsz,
...@@ -1558,8 +1598,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, ...@@ -1558,8 +1598,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
CUBLAS_GEMM_DEFAULT_TENSOR_OP); CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif #endif
} }
// cudaEventRecord(Context::Instance().GetCompEvent(2), // cudaEventRecord(InferenceContext::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true)); // InferenceContext::Instance().GetCurrentStream(true));
return output; return output;
} }
...@@ -1586,7 +1626,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state, ...@@ -1586,7 +1626,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state,
hidden_size, hidden_size,
mp_size, mp_size,
preln, preln,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
else else
launch_gptj_residual_add<T>( launch_gptj_residual_add<T>(
static_cast<T*>(residual.data_ptr()), static_cast<T*>(residual.data_ptr()),
...@@ -1597,7 +1637,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state, ...@@ -1597,7 +1637,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state,
hidden_size, hidden_size,
bsz, bsz,
mp_size, mp_size,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
return residual; return residual;
} }
...@@ -1627,8 +1667,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query, ...@@ -1627,8 +1667,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
bsz, bsz,
rotate_half, rotate_half,
rotate_every_two, rotate_every_two,
Context::Instance().GetCurrentStream(), InferenceContext::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght()); InferenceContext::Instance().GetMaxTokenLenght());
else else
launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(), launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
(__half*)key_cont.data_ptr(), (__half*)key_cont.data_ptr(),
...@@ -1640,8 +1680,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query, ...@@ -1640,8 +1680,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
bsz, bsz,
rotate_half, rotate_half,
rotate_every_two, rotate_every_two,
Context::Instance().GetCurrentStream(), InferenceContext::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght()); InferenceContext::Instance().GetMaxTokenLenght());
return {query_cont, key_cont}; return {query_cont, key_cont};
} }
...@@ -1670,7 +1710,7 @@ at::Tensor fused_gemm_gelu_int8(at::Tensor& input, ...@@ -1670,7 +1710,7 @@ at::Tensor fused_gemm_gelu_int8(at::Tensor& input,
(T*)bias.data_ptr(), (T*)bias.data_ptr(),
weight.size(1), weight.size(1),
bsz, bsz,
Context::Instance().GetCurrentStream()); InferenceContext::Instance().GetCurrentStream());
return output; return output;
} }
...@@ -1679,7 +1719,7 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out ...@@ -1679,7 +1719,7 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out
{ {
int M = moe_res.size(0) * moe_res.size(1); int M = moe_res.size(0) * moe_res.size(1);
int N = moe_res.size(2); int N = moe_res.size(2);
Context::Instance().SynchComm(); InferenceContext::Instance().SynchComm();
if (moe_res.scalar_type() == at::kFloat) { if (moe_res.scalar_type() == at::kFloat) {
launch_moe_res_matmul<float>((float*)moe_res.data_ptr(), launch_moe_res_matmul<float>((float*)moe_res.data_ptr(),
(float*)coef.data_ptr(), (float*)coef.data_ptr(),
...@@ -1698,83 +1738,77 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out ...@@ -1698,83 +1738,77 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out
return output; return output;
} }
void ds_release_workspace() { InferenceContext::Instance().release_workspace(); }
bool ds_retake_workspace() { return InferenceContext::Instance().retake_workspace(); }
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{ {
m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp16 (CUDA)");
m.def(
"softmax_context_fp32", &ds_softmax_context<float>, "DeepSpeed attention with fp32 (CUDA)");
m.def("softmax_context_fp16",
&ds_softmax_context<__half>,
"DeepSpeed attention with fp16 (CUDA)");
m.def("softmax_context_int8", m.def("softmax_context_int8",
&ds_softmax_context1<__half>, &ds_softmax_context1<__half>,
"DeepSpeed attention with int8 (CUDA)"); "DeepSpeed attention with int8 (CUDA)");
m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
m.def("bias_geglu", &ds_bias_geglu, "DeepSpeed Bias GEGLU (CUDA)"); m.def("bias_geglu", &ds_bias_geglu, "DeepSpeed Bias GEGLU (CUDA)");
m.def("bias_add_fp32", &ds_bias_add<float>, "DeepSpeed Bias Add with fp32 (CUDA)");
m.def("bias_add_fp16", &ds_bias_add<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
m.def("bias_relu_fp32", &ds_bias_relu<float>, "DeepSpeed ReLU with fp32 (CUDA)");
m.def("bias_relu_fp16", &ds_bias_relu<__half>, "DeepSpeed ReLU with fp16 (CUDA)");
m.def("bias_residual_fp32",
&ds_bias_residual<float>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
m.def("bias_residual_fp16",
&ds_bias_residual<__half>,
"DeepSpeed residual-bias add with fp16 (CUDA)");
m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm (CUDA)"); m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm (CUDA)");
m.def( m.def(
"_layer_norm_residual", &ds_layer_norm_residual, "DeepSpeed layer norm + residual (CUDA)"); "_layer_norm_residual", &ds_layer_norm_residual, "DeepSpeed layer norm + residual (CUDA)");
m.def("layer_norm_residual_store_pre_ln_res", m.def("layer_norm_residual_store_pre_ln_res",
&ds_layer_norm_residual_store_pre_ln_res, &ds_layer_norm_residual_store_pre_ln_res,
"DeepSpeed layer norm + store pre Layernorm residual (CUDA)"); "DeepSpeed layer norm + store pre Layernorm residual (CUDA)");
m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)");
m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)"); m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)");
m.def("mlp_gemm_fp32", &ds_mlp_gemm<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)"); m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)");
m.def("vector_matmul_fp32", &ds_vector_matmul<float>, "DeepSpeed vector-MM with fp32 (CUDA)");
m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)");
m.def("vector_matmul_int8", m.def("vector_matmul_int8",
&ds_vector_matmul_int8<__half>, &ds_vector_matmul_int8<__half>,
"DeepSpeed vector-MM with int8 (CUDA)"); "DeepSpeed vector-MM with int8 (CUDA)");
m.def("linear_layer_fp32", &ds_linear_layer<float>, "DeepSpeed linear_layer with fp32 (CUDA)");
m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)");
m.def("linear_layer_int8", m.def("linear_layer_int8",
&ds_linear_layer_int8<__half>, &ds_linear_layer_int8<__half>,
"DeepSpeed linear_layer with int8 (CUDA)"); "DeepSpeed linear_layer with int8 (CUDA)");
m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("residual_add_bias_fp32",
&residual_add_bias<float>,
"DeepSpeed residual add with fp32 (CUDA)");
m.def("residual_add_bias_fp16",
&residual_add_bias<__half>,
"DeepSpeed residual add with fp16 (CUDA)");
m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)"); m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
m.def("einsum_sec_sm_ecm_fp32",
&einsum_sec_sm_ecm<float>,
"DeepSpeed vector-MM with fp32 (CUDA)");
m.def("einsum_sec_sm_ecm_fp16",
&einsum_sec_sm_ecm<__half>,
"DeepSpeed vector-MM with fp16 (CUDA)");
m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)"); m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
m.def("add_padding_fp32", &add_padding<float>, "DeepSpeed residual add with fp32 (CUDA)");
m.def("add_padding_fp16", &add_padding<__half>, "DeepSpeed residual add with fp16 (CUDA)");
m.def("pad_transform_fp32",
&padd_add_transform<float>,
"DeepSpeed residual add with fp32 (CUDA)");
m.def("pad_transform_fp16",
&padd_add_transform<__half>,
"DeepSpeed residual add with fp16 (CUDA)");
m.def("allocate_workspace_fp32",
&allocate_workspace<float>,
"DeepSpeed memory allocation for GPT inference with fp32 (CUDA)");
m.def("allocate_workspace_fp16",
&allocate_workspace<__half>,
"DeepSpeed memory allocation for GPT inference with fp16 (CUDA)");
m.def("reset_cache", &reset_cache, "Reset Cache for generation tasks"); m.def("reset_cache", &reset_cache, "Reset Cache for generation tasks");
m.def("release_workspace", &ds_release_workspace, "DeepSpeed Release Workspace");
m.def("retake_workspace", &ds_retake_workspace, "DeepSpeed Retake Workspace");
#define DEF_OPS(_name, _dtype) \
m.def("softmax_" #_name, &ds_softmax<_dtype>, "DeepSpeed SoftMax with " #_name " (CUDA)"); \
m.def("softmax_context_" #_name, \
&ds_softmax_context<_dtype>, \
"DeepSpeed attention with _name (CUDA)"); \
m.def("bias_gelu_" #_name, &ds_bias_gelu<_dtype>, "DeepSpeed Gelu with " #_name " (CUDA)"); \
m.def("bias_add_" #_name, &ds_bias_add<_dtype>, "DeepSpeed Bias Add with " #_name " (CUDA)"); \
m.def("bias_relu_" #_name, &ds_bias_relu<_dtype>, "DeepSpeed ReLU with " #_name " (CUDA)"); \
m.def("bias_residual_" #_name, \
&ds_bias_residual<_dtype>, \
"DeepSpeed residual-bias add with " #_name " (CUDA)"); \
m.def("qkv_gemm_" #_name, &ds_qkv_gemm<_dtype>, "DeepSpeed qkv gemm with " #_name " (CUDA)"); \
m.def("mlp_gemm_" #_name, &ds_mlp_gemm<_dtype>, "DeepSpeed mlp with " #_name " (CUDA)"); \
m.def("vector_matmul_" #_name, \
&ds_vector_matmul<_dtype>, \
"DeepSpeed vector-MM with " #_name " (CUDA)"); \
m.def("linear_layer_" #_name, \
&ds_linear_layer<_dtype>, \
"DeepSpeed linear_layer with " #_name " (CUDA)"); \
m.def("fused_gemm_gelu_" #_name, \
&fused_gemm_gelu<_dtype>, \
"DeepSpeed mlp with " #_name " (CUDA)"); \
m.def("residual_add_bias_" #_name, \
&residual_add_bias<_dtype>, \
"DeepSpeed residual add with " #_name " (CUDA)"); \
m.def("einsum_sec_sm_ecm_" #_name, \
&einsum_sec_sm_ecm<_dtype>, \
"DeepSpeed vector-MM with " #_name " (CUDA)"); \
m.def("add_padding_" #_name, \
&add_padding<_dtype>, \
"DeepSpeed residual add with " #_name " (CUDA)"); \
m.def("pad_transform_" #_name, \
&padd_add_transform<_dtype>, \
"DeepSpeed residual add with " #_name " (CUDA)"); \
m.def("allocate_workspace_" #_name, \
&allocate_workspace<_dtype>, \
"DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)")
DEF_OPS(fp32, float);
DEF_OPS(fp16, __half);
#ifdef BF16_AVAILABLE
DEF_OPS(bf16, __nv_bfloat16);
#endif
} }
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include "conversion_utils.h" #include "conversion_utils.h"
#include "inference_cuda_layers.h" #include "inference_cuda_layers.h"
...@@ -60,4 +61,11 @@ void launch_bias_relu(T* input, ...@@ -60,4 +61,11 @@ void launch_bias_relu(T* input,
} }
template void launch_bias_relu<float>(float*, const float*, int, int, cudaStream_t); template void launch_bias_relu<float>(float*, const float*, int, int, cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_bias_relu<__nv_bfloat16>(__nv_bfloat16*,
const __nv_bfloat16*,
int,
int,
cudaStream_t);
#endif
template void launch_bias_relu<__half>(__half*, const __half*, int, int, cudaStream_t); template void launch_bias_relu<__half>(__half*, const __half*, int, int, cudaStream_t);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment