"src/array/cpu/coo_coalesce.cc" did not exist on "87bca129a48f3b4c9248577ac8273a4a8e6cdf40"
Commit c25a91b6 authored by aiss

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "custom_cuda_layers.h"
#include "memory_access_utils.h"
......
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <torch/extension.h>
#include <vector>
......
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "custom_cuda_layers.h"
#include "memory_access_utils.h"
......
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <cassert>
#include "custom_cuda_layers.h"
......
// DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
// https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
/*
DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
https://github.com/ptillet/torch-blocksparse/blob/master/csrc/utils.cpp
*/
#include <torch/extension.h>
#include <string>
......
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <cassert>
#include "memory_access_utils.h"
......
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
......
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#pragma once
......@@ -8,11 +9,11 @@ Copyright 2022 The Microsoft DeepSpeed Team
#define HALF_PRECISION_AVAILABLE = 1
#endif
#ifdef __HIPCC__
#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#endif
#include <cuda.h>
#include <cuda_fp16.h>
......
/*
Copyright The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "cublas_wrappers.h"
......
/*
Copyright The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "custom_cuda_layers.h"
......@@ -277,7 +278,7 @@ void launch_dropout(T* out,
grid_dim.x <<= 1;
}
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
if (bwd)
dropout_kernel_bwd<<<grid_dim, block_dim, 0, stream>>>(
total_count, ratio, vals, out, mask, seed);
......@@ -624,7 +625,7 @@ void launch_dropout(T* out,
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
total_count, dim, ratio, bias, out, mask, seed);
......@@ -846,7 +847,7 @@ void launch_dropout(T* out,
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
total_count, dim, ratio, input, residual, bias, out, mask, seed);
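// A minimal sketch (assumption: hypothetical names, not the DeepSpeed API) of
// the bookkeeping behind IncrementOffset() used above: each dropout launch is
// handed the current Philox {seed, offset} pair and the offset is advanced by
// the number of random values each thread will consume, so successive
// launches never reuse the same random stream.
struct RngStateSketch {
    uint64_t seed = 0;
    uint64_t curr_offset = 0;
    std::pair<uint64_t, uint64_t> increment_offset(uint64_t inc)
    {
        uint64_t offset = curr_offset;  // offset handed to this launch
        curr_offset += inc;             // reserve `inc` values per thread
        return {seed, offset};
    }
};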
......
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <torch/extension.h>
#include <cublas_v2.h>
......@@ -73,8 +78,8 @@ BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
_normalize_invertible(normalize_invertible),
_gelu_checkpoint(gelu_checkpoint),
_stochastic_mode(stochastic_mode),
_stream(Context::Instance().GetCurrentStream()),
_cublasHandle(Context::Instance().GetCublasHandle()),
_stream(TrainingContext::Instance().GetCurrentStream()),
_cublasHandle(TrainingContext::Instance().GetCublasHandle()),
_qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
3 * hidden_size,
hidden_size,
......@@ -179,7 +184,7 @@ void BertTransformerLayer<T>::Forward(unsigned bsz,
if (!_stochastic_mode) cudaStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
T* workspace = static_cast<T*>(TrainingContext::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size;
......@@ -339,7 +344,7 @@ void BertTransformerLayer<T>::Backward(unsigned bsz,
if (!_stochastic_mode) cudaStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
T* workspace = static_cast<T*>(TrainingContext::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size;
......@@ -605,25 +610,26 @@ int create_transformer_layer(unsigned layer_id,
bool gelu_checkpoint,
bool stochastic_mode)
{
Context::Instance().SetSeed(seed);
Context::Instance().TestGemmFP16(
TrainingContext::Instance().SetSeed(seed);
TrainingContext::Instance().TestGemmFP16(
test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);
auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
batch_size,
hidden_dim,
num_heads,
intermediate_size,
init_seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
layer_norm_eps,
pre_or_postLayerNorm,
Context::Instance().GetGemmAlgos(),
attn_dropout_checkpoint,
normalize_invertible,
gelu_checkpoint,
stochastic_mode);
auto layer =
std::make_shared<BertTransformerLayer<T>>(layer_id,
batch_size,
hidden_dim,
num_heads,
intermediate_size,
init_seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
layer_norm_eps,
pre_or_postLayerNorm,
TrainingContext::Instance().GetGemmAlgos(),
attn_dropout_checkpoint,
normalize_invertible,
gelu_checkpoint,
stochastic_mode);
s_transformer_layers[layer_id] = layer;
......@@ -721,7 +727,7 @@ std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
layer->IsTrainingMode(),
layer->GeluCheckpoint())},
options);
Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr());
auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
......@@ -905,7 +911,7 @@ std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
layer->IsTrainingMode(),
layer->GeluCheckpoint())},
options);
Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr());
auto grad_input = torch::empty_like(input);
auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
......
/*
Copyright The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "custom_cuda_layers.h"
......
/*
Copyright The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "general_kernels.h"
......
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "conversion_utils.h"
#include "inference_cuda_layers.h"
#ifndef __HIP_PLATFORM_HCC__
......@@ -11,8 +13,9 @@ Copyright 2022 The Microsoft DeepSpeed Team
namespace cg = cooperative_groups;
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
template <typename T>
__global__ void apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
......@@ -39,8 +42,8 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[k_offset + lane];
float q = conversion::to<float>(mixed_query[offset + lane]);
float k = conversion::to<float>(key_layer[k_offset + lane]);
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
......@@ -49,59 +52,14 @@ __global__ void apply_rotary_pos_emb(float* mixed_query,
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[k_offset + lane] = k;
mixed_query[offset + lane] = conversion::to<T>(q);
key_layer[k_offset + lane] = conversion::to<T>(k);
lane += WARP_SIZE;
}
}
}
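// A minimal host-side sketch (assumptions: standalone helper, illustrative
// names) of the rotate-every-two RoPE math the templated kernel above
// implements with shfl_xor pairing; the same rotation is applied to the key.
inline void rope_every_two_reference(float* q, unsigned rotary_dim, unsigned seq_id)
{
    for (unsigned i = 0; i + 1 < rotary_dim; i += 2) {
        // theta = seq_id / 10000^(i / rotary_dim), shared by the even/odd pair
        float theta = 1.0f / powf(10000.0f, (float)i / (float)rotary_dim) * (float)seq_id;
        float q_even = q[i];
        float q_odd = q[i + 1];
        q[i] = q_even * cosf(theta) - q_odd * sinf(theta);
        q[i + 1] = q_odd * cosf(theta) + q_even * sinf(theta);
    }
}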
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[k_offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb1(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
......@@ -147,8 +105,10 @@ __global__ void apply_rotary_pos_emb1(float* mixed_query,
}
}
}
__global__ void apply_rotary_pos_emb1(__half* mixed_query,
__half* key_layer,
template <typename T>
__global__ void apply_rotary_pos_emb1(T* mixed_query,
T* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
......@@ -184,8 +144,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
while (lane < rotary_dim) {
float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[k_offset + lane];
float q = conversion::to<float>(mixed_query[offset + lane]);
float k = conversion::to<float>(key_layer[k_offset + lane]);
float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
......@@ -196,8 +156,8 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query,
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[k_offset + lane] = (__half)k;
mixed_query[offset + lane] = conversion::to<T>(q);
key_layer[k_offset + lane] = conversion::to<T>(k);
lane += WARP_SIZE;
}
......@@ -255,6 +215,20 @@ template void launch_apply_rotary_pos_emb<float>(float*,
bool,
cudaStream_t,
int);
#ifdef BF16_AVAILABLE
template void launch_apply_rotary_pos_emb<__nv_bfloat16>(__nv_bfloat16*,
__nv_bfloat16*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
cudaStream_t,
int);
#endif
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
......@@ -268,6 +242,59 @@ template void launch_apply_rotary_pos_emb<__half>(__half*,
cudaStream_t,
int);
template __global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);
#ifdef BF16_AVAILABLE
template __global__ void apply_rotary_pos_emb(__nv_bfloat16* mixed_query,
__nv_bfloat16* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);
#endif
template __global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);
#ifdef BF16_AVAILABLE
template __global__ void apply_rotary_pos_emb1(__nv_bfloat16* mixed_query,
__nv_bfloat16* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);
#endif
template __global__ void apply_rotary_pos_emb1(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens);
/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
......
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "conversion_utils.h"
#include "inference_cuda_layers.h"
#define MAX_QUANTIZE_GROUPING 1024
......@@ -9,7 +11,8 @@ Copyright 2022 The Microsoft DeepSpeed Team
#define loop_unroll 1
#define loop_unroll_bits 1
__global__ void dequantize_kernel(float* output,
template <typename T>
__global__ void dequantize_kernel(T* output,
const int8_t* input,
const float* qscale,
int output_size,
......@@ -37,40 +40,7 @@ __global__ void dequantize_kernel(float* output,
float scale_data = qscale[scale_index];
output[q_index] = (scale_data * (float)q);
tid += blockDim.x;
}
}
__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count)
{
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = __float2half(scale_data * (float)q);
output[q_index] = conversion::to<T>(scale_data * (float)q);
tid += blockDim.x;
}
}
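// A simplified host-side sketch (assumption: illustrative only, ignoring the
// merged-weight reindexing that merge_count performs above) of group-wise
// int8 dequantization: each group of `group_stride` consecutive elements
// shares one float scale, and the output is scale * (float)q.
inline void dequantize_reference(float* output,
                                 const int8_t* input,
                                 const float* qscale,
                                 int total_elems,
                                 int group_stride)
{
    for (int i = 0; i < total_elems; i++) {
        int scale_index = i / group_stride;  // quantization group owning element i
        output[i] = qscale[scale_index] * (float)input[i];
    }
}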
......@@ -101,6 +71,18 @@ template void launch_dequantize<float>(float*,
unsigned,
unsigned,
cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_dequantize<__nv_bfloat16>(__nv_bfloat16*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
cudaStream_t);
#endif
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
......@@ -119,7 +101,8 @@ __global__ void dequantize_kernel(float* output,
{
}
__global__ void dequantize_kernel(__half* output,
template <typename T>
__global__ void dequantize_kernel(T* output,
const int8_t* input,
const float* qscale,
unsigned hidden_dim,
......@@ -143,12 +126,12 @@ __global__ void dequantize_kernel(__half* output,
int8_t* q_int8 = (int8_t*)&q;
float2 q_f;
__half* q_h = (__half*)&q_f;
T* q_h = (T*)&q_f;
q_h[0] = __float2half(local_scale * (float)q_int8[0]);
q_h[1] = __float2half(local_scale * (float)q_int8[1]);
q_h[2] = __float2half(local_scale * (float)q_int8[2]);
q_h[3] = __float2half(local_scale * (float)q_int8[3]);
q_h[0] = conversion::to<T>(local_scale * (float)q_int8[0]);
q_h[1] = conversion::to<T>(local_scale * (float)q_int8[1]);
q_h[2] = conversion::to<T>(local_scale * (float)q_int8[2]);
q_h[3] = conversion::to<T>(local_scale * (float)q_int8[3]);
output_cast[tid] = q_f;
tid += blockDim.x;
}
......@@ -185,6 +168,17 @@ template void launch_dequantize<float>(float*,
unsigned,
unsigned,
cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_dequantize<__nv_bfloat16>(__nv_bfloat16*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
cudaStream_t);
#endif
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
......@@ -192,3 +186,45 @@ template void launch_dequantize<__half>(__half*,
unsigned,
unsigned,
cudaStream_t);
template __global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count);
#ifdef BF16_AVAILABLE
template __global__ void dequantize_kernel(__nv_bfloat16* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count);
#endif
template __global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count);
#ifdef BF16_AVAILABLE
template __global__ void dequantize_kernel(__nv_bfloat16* output,
const int8_t* input,
const float* qscale,
unsigned hidden_dim,
unsigned merge_hidden,
int cnt);
#endif
template __global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned hidden_dim,
unsigned merge_hidden,
int cnt);
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "conversion_utils.h"
#include "inference_cuda_layers.h"
......@@ -10,6 +11,11 @@ namespace cg = cooperative_groups;
#define MAX_CAP 4
#define MAX_SEQ 2048
// only used to avoid compilation error due to lack of definition.
#ifndef BF16_AVAILABLE
using __nv_bfloat162 = __half2;
#endif
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
......@@ -65,6 +71,13 @@ void launch_bias_gelu(T* input,
}
template void launch_bias_gelu<float>(float*, const float*, int, int, cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_bias_gelu<__nv_bfloat16>(__nv_bfloat16*,
const __nv_bfloat16*,
int,
int,
cudaStream_t);
#endif
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t);
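// A reference sketch (assumption: the standard tanh-based GeLU approximation,
// consistent with the sqrt(2/pi) constant declared in gelu() above) of the
// activation that launch_bias_gelu applies after adding the bias.
inline float gelu_reference(float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f;  // sqrt(2/pi)
    const float mul_param = 0.044715f;
    return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}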
/*
......@@ -115,6 +128,13 @@ void launch_bias_add(T* input,
}
template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_bias_add<__nv_bfloat16>(__nv_bfloat16*,
const __nv_bfloat16*,
int,
int,
cudaStream_t);
#endif
template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t);
__global__ void fused_bias_residual(float* residual,
......@@ -162,16 +182,19 @@ __global__ void fused_bias_residual(float* residual,
}
}
__global__ void fused_bias_residual(__half* residual,
const __half* hidden_state,
const __half* attn,
const __half* bias,
const __half* attn_bias,
template <typename T>
__global__ void fused_bias_residual(T* residual,
const T* hidden_state,
const T* attn,
const T* bias,
const T* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale,
const bool preln)
{
using T2 =
typename std::conditional<std::is_same<T, __half>::value, __half2, __nv_bfloat162>::type;
float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
......@@ -186,26 +209,26 @@ __global__ void fused_bias_residual(__half* residual,
const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];
__half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2);
const __half2* hs_half2 = reinterpret_cast<const __half2*>(&hs_fl2);
const __half2* attn_half2 = reinterpret_cast<const __half2*>(&attn_fl2);
const __half2* bias_half2 = reinterpret_cast<const __half2*>(&bias_fl2);
const __half2* attn_bias_half2 = reinterpret_cast<const __half2*>(&attn_bias_fl2);
T2* res_half2 = reinterpret_cast<T2*>(&res_fl2);
const T2* hs_half2 = reinterpret_cast<const T2*>(&hs_fl2);
const T2* attn_half2 = reinterpret_cast<const T2*>(&attn_fl2);
const T2* bias_half2 = reinterpret_cast<const T2*>(&bias_fl2);
const T2* attn_bias_half2 = reinterpret_cast<const T2*>(&attn_bias_fl2);
float2 res_low = __half22float2(res_half2[0]);
float2 res_high = __half22float2(res_half2[1]);
float2 res_low = conversion::to<float2>(res_half2[0]);
float2 res_high = conversion::to<float2>(res_half2[1]);
const float2 hs_low = __half22float2(hs_half2[0]);
const float2 hs_high = __half22float2(hs_half2[1]);
const float2 hs_low = conversion::to<float2>(hs_half2[0]);
const float2 hs_high = conversion::to<float2>(hs_half2[1]);
const float2 attn_low = __half22float2(attn_half2[0]);
const float2 attn_high = __half22float2(attn_half2[1]);
const float2 attn_low = conversion::to<float2>(attn_half2[0]);
const float2 attn_high = conversion::to<float2>(attn_half2[1]);
const float2 bias_low = __half22float2(bias_half2[0]);
const float2 bias_high = __half22float2(bias_half2[1]);
const float2 bias_low = conversion::to<float2>(bias_half2[0]);
const float2 bias_high = conversion::to<float2>(bias_half2[1]);
const float2 attn_bias_low = __half22float2(attn_bias_half2[0]);
const float2 attn_bias_high = __half22float2(attn_bias_half2[1]);
const float2 attn_bias_low = conversion::to<float2>(attn_bias_half2[0]);
const float2 attn_bias_high = conversion::to<float2>(attn_bias_half2[1]);
if (preln) {
// residual = (residual + attention + bias + attention_bias) *
......@@ -225,8 +248,8 @@ __global__ void fused_bias_residual(__half* residual,
res_high.x = (res_high.x + hs_high.x + bias_high.x);
res_high.y = (res_high.y + hs_high.y + bias_high.y);
}
res_half2[0] = __float22half2_rn(res_low);
res_half2[1] = __float22half2_rn(res_high);
res_half2[0] = conversion::to<T2>(res_low);
res_half2[1] = conversion::to<T2>(res_high);
res_fl2_ptr[offset] = res_fl2;
}
......@@ -261,9 +284,43 @@ void launch_bias_residual(T* residual,
template void launch_bias_residual<
float>(float*, float*, float*, float*, float*, int, int, int, bool, cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_bias_residual<__nv_bfloat16>(__nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
int,
int,
int,
bool,
cudaStream_t);
#endif
template void launch_bias_residual<
__half>(__half*, __half*, __half*, __half*, __half*, int, int, int, bool, cudaStream_t);
#ifdef BF16_AVAILABLE
template __global__ void fused_bias_residual(__nv_bfloat16* residual,
const __nv_bfloat16* hidden_state,
const __nv_bfloat16* attn,
const __nv_bfloat16* bias,
const __nv_bfloat16* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale,
const bool preln);
#endif
template __global__ void fused_bias_residual(__half* residual,
const __half* hidden_state,
const __half* attn,
const __half* bias,
const __half* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale,
const bool preln);
__global__ void gptj_residual_add(float* residual,
const float* hidden_state,
const float* attn,
......@@ -304,15 +361,18 @@ __global__ void gptj_residual_add(float* residual,
}
}
__global__ void gptj_residual_add(__half* residual,
const __half* hidden_state,
const __half* attn,
const __half* bias,
const __half* attn_bias,
template <typename T>
__global__ void gptj_residual_add(T* residual,
const T* hidden_state,
const T* attn,
const T* bias,
const T* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale)
{
using T2 =
typename std::conditional<std::is_same<T, __half>::value, __half2, __nv_bfloat162>::type;
float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
......@@ -326,28 +386,28 @@ __global__ void gptj_residual_add(__half* residual,
const float2 attn_fl2 = attn_fl2_ptr[offset];
const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
__half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2);
const __half2* hs_half2 = reinterpret_cast<const __half2*>(&hs_fl2);
const __half2* attn_half2 = reinterpret_cast<const __half2*>(&attn_fl2);
const __half2* bias_half2 = reinterpret_cast<const __half2*>(&bias_fl2);
T2* res_half2 = reinterpret_cast<T2*>(&res_fl2);
const T2* hs_half2 = reinterpret_cast<const T2*>(&hs_fl2);
const T2* attn_half2 = reinterpret_cast<const T2*>(&attn_fl2);
const T2* bias_half2 = reinterpret_cast<const T2*>(&bias_fl2);
float2 res_low = __half22float2(res_half2[0]);
float2 res_high = __half22float2(res_half2[1]);
float2 res_low = conversion::to<float2>(res_half2[0]);
float2 res_high = conversion::to<float2>(res_half2[1]);
const float2 hs_low = __half22float2(hs_half2[0]);
const float2 hs_high = __half22float2(hs_half2[1]);
const float2 hs_low = conversion::to<float2>(hs_half2[0]);
const float2 hs_high = conversion::to<float2>(hs_half2[1]);
const float2 attn_low = __half22float2(attn_half2[0]);
const float2 attn_high = __half22float2(attn_half2[1]);
const float2 attn_low = conversion::to<float2>(attn_half2[0]);
const float2 attn_high = conversion::to<float2>(attn_half2[1]);
const float2 bias_low = __half22float2(bias_half2[0]);
const float2 bias_high = __half22float2(bias_half2[1]);
const float2 bias_low = conversion::to<float2>(bias_half2[0]);
const float2 bias_high = conversion::to<float2>(bias_half2[1]);
if (attn_bias) {
const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];
const __half2* attn_bias_half2 = reinterpret_cast<const __half2*>(&attn_bias_fl2);
const float2 attn_bias_low = __half22float2(attn_bias_half2[0]);
const float2 attn_bias_high = __half22float2(attn_bias_half2[1]);
const T2* attn_bias_half2 = reinterpret_cast<const T2*>(&attn_bias_fl2);
const float2 attn_bias_low = conversion::to<float2>(attn_bias_half2[0]);
const float2 attn_bias_high = conversion::to<float2>(attn_bias_half2[1]);
// residual += attention_bias
res_low.x += attn_bias_low.x;
res_low.y += attn_bias_low.y;
......@@ -360,8 +420,8 @@ __global__ void gptj_residual_add(__half* residual,
res_high.x = attn_high.x + hs_high.x + (res_high.x + bias_high.x) * mp_scale;
res_high.y = attn_high.y + hs_high.y + (res_high.y + bias_high.y) * mp_scale;
res_half2[0] = __float22half2_rn(res_low);
res_half2[1] = __float22half2_rn(res_high);
res_half2[0] = conversion::to<T2>(res_low);
res_half2[1] = conversion::to<T2>(res_high);
res_fl2_ptr[offset] = res_fl2;
}
......@@ -395,6 +455,19 @@ template void launch_gptj_residual_add<float>(float*,
int,
int,
cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_gptj_residual_add<__nv_bfloat16>(__nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
__nv_bfloat16*,
int,
int,
int,
cudaStream_t);
#endif
template void launch_gptj_residual_add<__half>(__half*,
__half*,
__half*,
......@@ -404,6 +477,27 @@ template void launch_gptj_residual_add<__half>(__half*,
int,
int,
cudaStream_t);
#ifdef BF16_AVAILABLE
template __global__ void gptj_residual_add(__nv_bfloat16* residual,
const __nv_bfloat16* hidden_state,
const __nv_bfloat16* attn,
const __nv_bfloat16* bias,
const __nv_bfloat16* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale);
#endif
template __global__ void gptj_residual_add(__half* residual,
const __half* hidden_state,
const __half* attn,
const __half* bias,
const __half* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale);
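// A scalar sketch (assumption: elementwise form only, ignoring the float2/T2
// vectorization) of the arithmetic in gptj_residual_add as instantiated above:
// the residual (carrying the MLP output) is bias-corrected, scaled by
// 1/mp_size, and added to the attention output and the original hidden state;
// an optional attention bias is folded into the residual first.
inline float gptj_residual_add_reference(float residual,
                                         float hidden_state,
                                         float attn,
                                         float bias,
                                         const float* attn_bias,  // may be nullptr
                                         float mp_scale)
{
    if (attn_bias) residual += *attn_bias;
    return attn + hidden_state + (residual + bias) * mp_scale;
}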
template <typename T>
__global__ void moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim)
{
......@@ -454,6 +548,16 @@ template void launch_moe_res_matmul(float* residual,
int seq_len,
int hidden_dim,
cudaStream_t stream);
#ifdef BF16_AVAILABLE
template void launch_moe_res_matmul(__nv_bfloat16* residual,
__nv_bfloat16* coef,
__nv_bfloat16* mlp_out,
int seq_len,
int hidden_dim,
cudaStream_t stream);
#endif
template void launch_moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
......@@ -461,11 +565,11 @@ template void launch_moe_res_matmul(__half* residual,
int hidden_dim,
cudaStream_t stream);
__global__ void pad_data_kernel(__half* padded_output,
__half* output,
int head_size,
int padded_head_size)
template <typename T>
__global__ void pad_data_kernel(T* padded_output, T* output, int head_size, int padded_head_size)
{
using T2 =
typename std::conditional<std::is_same<T, __half>::value, __half2, __nv_bfloat162>::type;
float4* padded_output_cast = reinterpret_cast<float4*>(padded_output);
float4* output_cast = reinterpret_cast<float4*>(output);
int bid = blockIdx.x * (blockDim.y) + threadIdx.y;
......@@ -473,8 +577,8 @@ __global__ void pad_data_kernel(__half* padded_output,
padded_output_cast += (bid * padded_head_size);
output_cast += (bid * head_size);
float4 ZERO;
const __half2 zero_h = __float2half2_rn(0.f);
__half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO);
const T2 zero_h = conversion::to<T2>(0.f);
T2* ZERO_h = reinterpret_cast<T2*>(&ZERO);
#pragma unroll
for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h;
if (idx < head_size)
......@@ -482,12 +586,14 @@ __global__ void pad_data_kernel(__half* padded_output,
else
padded_output_cast[idx] = ZERO;
}
__global__ void pad_data_kernel(float* padded_output,
float* output,
int head_size,
int padded_head_size)
{
}
template <typename T>
void pad_data(T* padded_output,
T* output,
......@@ -507,6 +613,16 @@ template void pad_data(__half* padded_output,
int head_size,
int padded_head_size,
cudaStream_t stream);
#ifdef BF16_AVAILABLE
template void pad_data(__nv_bfloat16* padded_output,
__nv_bfloat16* output,
int bsz,
int head_size,
int padded_head_size,
cudaStream_t stream);
#endif
template void pad_data(float* padded_output,
float* output,
int bsz,
......@@ -514,13 +630,28 @@ template void pad_data(float* padded_output,
int padded_head_size,
cudaStream_t stream);
__global__ void pad_head_seq_kernel(__half* padded_output,
__half* output,
#ifdef BF16_AVAILABLE
template __global__ void pad_data_kernel(__nv_bfloat16* padded_output,
__nv_bfloat16* output,
int head_size,
int padded_head_size);
#endif
template __global__ void pad_data_kernel(__half* padded_output,
__half* output,
int head_size,
int padded_head_size);
template <typename T>
__global__ void pad_head_seq_kernel(T* padded_output,
T* output,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size)
{
using T2 =
typename std::conditional<std::is_same<T, __half>::value, __half2, __nv_bfloat162>::type;
float4* padded_output_cast = reinterpret_cast<float4*>(padded_output);
float4* output_cast = reinterpret_cast<float4*>(output);
int bsz = blockIdx.x;
......@@ -529,8 +660,8 @@ __global__ void pad_head_seq_kernel(__half* padded_output,
padded_output_cast += (bsz * padded_seq_len + bid) * padded_head_size;
output_cast += (bsz * seq_len + bid) * head_size;
float4 ZERO;
const __half2 zero_h = __float2half2_rn(0.f);
__half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO);
const T2 zero_h = conversion::to<T2>(0.f);
T2* ZERO_h = reinterpret_cast<T2*>(&ZERO);
#pragma unroll
for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h;
......@@ -539,6 +670,7 @@ __global__ void pad_head_seq_kernel(__half* padded_output,
else
padded_output_cast[idx] = ZERO;
}
__global__ void pad_head_seq_kernel(float* padded_output,
float* output,
int seq_len,
......@@ -547,6 +679,7 @@ __global__ void pad_head_seq_kernel(float* padded_output,
int padded_head_size)
{
}
template <typename T>
void pad_head_seq(T* padded_output,
T* output,
......@@ -562,6 +695,7 @@ void pad_head_seq(T* padded_output,
pad_head_seq_kernel<<<grid_dim, block_dim, 0, stream>>>(
padded_output, output, seq_len, padded_seq_len, head_size / 8, padded_head_size / 8);
}
template void pad_head_seq(__half* padded_output,
__half* output,
int bsz,
......@@ -570,6 +704,18 @@ template void pad_head_seq(__half* padded_output,
int head_size,
int padded_head_size,
cudaStream_t stream);
#ifdef BF16_AVAILABLE
template void pad_head_seq(__nv_bfloat16* padded_output,
__nv_bfloat16* output,
int bsz,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size,
cudaStream_t stream);
#endif
template void pad_head_seq(float* padded_output,
float* output,
int bsz,
......@@ -680,4 +826,12 @@ template void launch_fused_bias_geglu(__half*,
int,
int,
cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_fused_bias_geglu(__nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
int,
int,
cudaStream_t);
#endif
template void launch_fused_bias_geglu(float*, const float*, const float*, int, int, cudaStream_t);
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
......@@ -196,6 +197,16 @@ template void launch_fused_ln(__half*,
int,
int,
cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_fused_ln(__nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
float,
int,
int,
cudaStream_t);
#endif
template void
launch_fused_ln(float*, const float*, const float*, const float*, float, int, int, cudaStream_t);
......@@ -492,6 +503,19 @@ template void launch_fused_residual_ln(__half*,
int,
cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_fused_residual_ln(__nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
float,
int,
int,
cudaStream_t);
#endif
template void launch_fused_residual_ln(float*,
const float*,
const float*,
......@@ -516,6 +540,20 @@ template void launch_fused_residual_ln_store_pre_ln_res(__half*,
int,
cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_fused_residual_ln_store_pre_ln_res(__nv_bfloat16*,
__nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
const __nv_bfloat16*,
float,
int,
int,
cudaStream_t);
#endif
template void launch_fused_residual_ln_store_pre_ln_res(float*,
float*,
const float*,
......
#include <limits>
#include "custom_cuda_layers.h"
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define NORM_REG (MAX_REGISTERS)
namespace cg = cooperative_groups;
__global__ void fused_bias_residual_layer_norm(float* output,
const float* vals,
const float* gamma,
const float* beta,
float epsilon,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
float sum = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
output[out_id + row * row_stride] = inp_reg[f];
}
}
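// A device-side sketch (assumption: standalone helper, mirroring the inlined
// pattern rather than replacing it) of the two-stage sum reduction both
// layer-norm kernels in this file repeat: a shfl_down tree inside each warp,
// one shared-memory exchange across warps, then a final warp-level tree.
__device__ float block_reduce_sum_sketch(cg::thread_block b,
                                         cg::thread_block_tile<32> g,
                                         float sum)
{
    __shared__ float shr[MAX_WARP_NUM];
    int gid = threadIdx.x >> 5;
    int warp_num = blockDim.x >> 5;
    for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
    if (g.thread_rank() == 0) shr[gid] = sum;
    b.sync();
    if (g.thread_rank() < warp_num) sum = shr[g.thread_rank()];
    b.sync();
    for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
    return g.shfl(sum, 0);  // broadcast the block-wide sum to every lane
}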
__global__ void fused_bias_residual_layer_norm(__half* output,
const __half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
const __half2* vals_cast = reinterpret_cast<const __half2*>(vals);
__half2* out_cast = reinterpret_cast<__half2*>(output);
int k = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k++] = vals_cast[input_id + row * row_stride];
input_id += iteration_stride;
}
float sum = 0;
for (int f = k - 1; f >= 0; f--) {
float2 inp_f = __half22float2(inp_reg[f]);
sum += inp_f.x + inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
out_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_layer_norm(T* out,
T* vals,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream);
template <>
void launch_layer_norm<float>(float* out,
float* vals,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
out, vals, gamma, beta, epsilon, hidden_dim);
}
template <>
void launch_layer_norm<__half>(__half* out,
__half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_bias_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(
out, vals, gamma, beta, epsilon, hidden_dim / 2);
}
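// A host-side reference sketch (assumption: math only, no vectorization or
// shared-memory reductions) of what launch_layer_norm computes per row:
// subtract the row mean, scale by rsqrt(variance + epsilon), then apply
// gamma and beta.
inline void layer_norm_reference(float* out,
                                 const float* vals,
                                 const float* gamma,
                                 const float* beta,
                                 float epsilon,
                                 int batch_size,
                                 int hidden_dim)
{
    for (int r = 0; r < batch_size; r++) {
        const float* x = vals + r * hidden_dim;
        float mean = 0.f;
        for (int i = 0; i < hidden_dim; i++) mean += x[i];
        mean /= hidden_dim;
        float var = 0.f;
        for (int i = 0; i < hidden_dim; i++) var += (x[i] - mean) * (x[i] - mean);
        var /= hidden_dim;
        float rstd = 1.f / sqrtf(var + epsilon);
        for (int i = 0; i < hidden_dim; i++)
            out[r * hidden_dim + i] = (x[i] - mean) * rstd * gamma[i] + beta[i];
    }
}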
__global__ void fused_residual_layer_norm(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
float res_f = (residual[input_id + row * row_stride]);
float bias_f = (bias[input_id]);
if (mlp_after_attn) inp_reg[k] += res_f + bias_f;
// if (preLN) res_add[input_id + row * row_stride] = inp_reg[k];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
norm[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_residual_layer_norm(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
__half2* norm_cast = reinterpret_cast<__half2*>(norm);
__half2* res_add_cast = reinterpret_cast<__half2*>(res_add);
__half2* residual_cast = reinterpret_cast<__half2*>(residual);
const __half2* bias_cast = reinterpret_cast<const __half2*>(bias);
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals_cast[input_id + row * row_stride];
float2 inp_f = __half22float2(inp_reg[k]);
float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]);
float2 bias_f = __half22float2(bias_cast[input_id]);
if (mlp_after_attn) {
inp_f.x += res_f.x + bias_f.x;
inp_f.y += res_f.y + bias_f.y;
}
inp_reg[k] = __float22half2_rn(inp_f);
// if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f);
// //inp_reg[k];
sum += inp_f.x + inp_f.y;
input_id += iteration_stride;
k++;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
norm_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_residual_layer_norm(T* norm,
T* res_add,
T* vals,
T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream);
template <>
void launch_residual_layer_norm<float>(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim,
preLN,
mlp_after_attn);
}
template <>
void launch_residual_layer_norm<__half>(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
cudaStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
fused_residual_layer_norm<<<grid_dim, block_dim, 0, stream>>>(norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim / 2,
preLN,
mlp_after_attn);
}
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
......@@ -95,7 +96,7 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
head_offset,
mask_stride,
mp_size,
Context::Instance().GetCurrentStream(async_op));
InferenceContext::Instance().GetCurrentStream(async_op));
return attn_scores_c;
}
......@@ -109,18 +110,20 @@ void allocate_workspace(unsigned hidden_dim,
unsigned mp_size = 1,
bool external_cache = false,
unsigned rank = 0,
unsigned max_out_tokens = 1024)
unsigned max_out_tokens = 1024,
unsigned min_out_tokens = 1)
{
Context::Instance().GenWorkSpace(num_layers,
num_heads,
batch_size,
prompt_length,
hidden_dim,
mp_size,
external_cache,
sizeof(T),
rank,
max_out_tokens);
InferenceContext::Instance().GenWorkSpace(num_layers,
num_heads,
batch_size,
prompt_length,
hidden_dim,
mp_size,
external_cache,
sizeof(T),
rank,
max_out_tokens,
min_out_tokens);
}
template <typename T>
......@@ -131,15 +134,15 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
float alpha = 1;
float gemm_beta = 0.0;
/*
// Reallocate memory if we received a new prompt
if (!workspace || input.size(1) != 1) {
allocate_workspace<T>(W.size(1), Context::Instance().GetMaxTokenLenght(), Q.size(0), 1,
head_size); workspace = (T*)Context::Instance().GetWorkSpace();
allocate_workspace<T>(W.size(1), InferenceContext::Instance().GetMaxTokenLenght(),
Q.size(0), 1, head_size); workspace = (T*)InferenceContext::Instance().GetWorkSpace();
}
*/
......@@ -147,7 +150,7 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
unsigned m = W.size(1);
unsigned n = Q.size(1) * Q.size(2);
unsigned k = Q.size(0);
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_T,
m,
......@@ -194,8 +197,9 @@ void attention_unfused(at::Tensor& prev_key_cont,
auto mask_stride = get_attn_mask_stride(attn_mask);
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
soft_len,
seq_len,
k,
......@@ -230,9 +234,9 @@ void attention_unfused(at::Tensor& prev_key_cont,
0,
mask_stride,
1,
Context::Instance().GetCurrentStream(false));
InferenceContext::Instance().GetCurrentStream(false));
alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
k,
seq_len,
soft_len,
......@@ -363,10 +367,11 @@ void attention_unfused(T* prev_key_cont,
float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0;
float alpha = norm_factor * norm_factor / layer_scale;
float gemm_beta = 0.0;
T* workspace = (T*)Context::Instance().GetAttentionUnfusedWorkspace();
T* workspace = (T*)InferenceContext::Instance().GetAttentionUnfusedWorkspace();
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
soft_len,
seq_len,
k,
......@@ -377,7 +382,7 @@ void attention_unfused(T* prev_key_cont,
workspace,
CUBLAS_OP_T,
CUBLAS_OP_N,
Context::Instance().GetMaxTokenLenght() * k,
InferenceContext::Instance().GetMaxTokenLenght() * k,
seq_len * k,
seq_len * soft_len,
bsz * heads,
......@@ -399,7 +404,7 @@ void attention_unfused(T* prev_key_cont,
soft_len,
heads);
alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
cublas_strided_batched_gemm(InferenceContext::Instance().GetCublasHandle(),
k,
seq_len,
soft_len,
......@@ -410,7 +415,7 @@ void attention_unfused(T* prev_key_cont,
(T*)output,
CUBLAS_OP_N,
CUBLAS_OP_N,
Context::Instance().GetMaxTokenLenght() * k,
InferenceContext::Instance().GetMaxTokenLenght() * k,
seq_len * soft_len,
seq_len * k,
bsz * heads,
......@@ -421,7 +426,7 @@ void attention_unfused(T* prev_key_cont,
#endif
}
void reset_cache() { Context::Instance().reset_tokens(); }
void reset_cache() { InferenceContext::Instance().reset_tokens(); }
template <typename T>
std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
......@@ -445,8 +450,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
bool is_prompt = (seq_len > 1);
if (is_prompt) Context::Instance().reset_tokens(seq_len);
unsigned soft_len = Context::Instance().current_tokens();
if (is_prompt) InferenceContext::Instance().reset_tokens(seq_len);
unsigned soft_len = InferenceContext::Instance().current_tokens();
int k = hidden_dim / heads;
auto options = at::TensorOptions()
......@@ -455,16 +460,17 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
size_t buf_size = bsz * seq_len * hidden_dim;
auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options);
auto query_cont = workspace + 8 * buf_size;
size_t offset = 16 * (hidden_dim * bsz * Context::Instance().GetMaxTokenLenght()) +
layer_id * 2 * bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim;
auto query_cont = workspace + 5 * buf_size;
size_t offset =
10 * (hidden_dim * bsz * InferenceContext::Instance().GetMaxTokenLenght()) +
layer_id * 2 * bsz * InferenceContext::Instance().GetMaxTokenLenght() * hidden_dim;
unsigned all_tokens = soft_len;
auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1);
size_t value_offset = bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim;
size_t value_offset = bsz * InferenceContext::Instance().GetMaxTokenLenght() * hidden_dim;
T* temp_buf = (T*)output.data_ptr() + at::numel(output);
launch_bias_add_transform_0213<T>((T*)query_cont,
......@@ -481,9 +487,9 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
rotary_dim,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream(),
InferenceContext::Instance().GetCurrentStream(),
3,
Context::Instance().GetMaxTokenLenght());
InferenceContext::Instance().GetMaxTokenLenght());
if (rotary_dim > 0 && rotate_half)
launch_apply_rotary_pos_emb(query_cont,
kv_cache,
......@@ -495,8 +501,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght());
InferenceContext::Instance().GetCurrentStream(),
InferenceContext::Instance().GetMaxTokenLenght());
attention_unfused<T>(workspace + offset,
(T*)query_cont,
......@@ -521,13 +527,27 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
heads,
seq_len,
output.size(2),
Context::Instance().GetCurrentStream(false),
InferenceContext::Instance().GetCurrentStream(false),
1);
if (layer_id == num_layers - 1) Context::Instance().advance_tokens();
auto prev_key = torch::from_blob(workspace + offset, {bsz, heads, all_tokens, k}, options);
if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens();
auto prev_key = torch::from_blob(workspace + offset,
{bsz, heads, all_tokens, k},
{hidden_dim * InferenceContext::Instance().GetMaxTokenLenght(),
k * InferenceContext::Instance().GetMaxTokenLenght(),
k,
1},
options);
auto prev_value =
torch::from_blob(workspace + offset + value_offset, {bsz, heads, all_tokens, k}, options);
torch::from_blob(workspace + offset + value_offset,
{bsz, heads, all_tokens, k},
{hidden_dim * InferenceContext::Instance().GetMaxTokenLenght(),
k * InferenceContext::Instance().GetMaxTokenLenght(),
k,
1},
options);
return {output, prev_key, prev_value};
}
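// An indexing sketch (assumption: illustrative helper only) of the strided
// KV-cache views built with torch::from_blob above: the cache is laid out as
// [bsz, heads, max_out_tokens, head_size], and the view exposes just the
// first all_tokens positions of every head without copying.
inline size_t kv_cache_index_sketch(size_t b,
                                    size_t h,
                                    size_t t,
                                    size_t d,
                                    size_t heads,
                                    size_t max_out_tokens,
                                    size_t head_size)
{
    size_t hidden_dim = heads * head_size;
    return b * hidden_dim * max_out_tokens    // batch stride
           + h * head_size * max_out_tokens   // head stride
           + t * head_size                    // token stride
           + d;                               // element stride
}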
......@@ -543,7 +563,7 @@ at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
(T*)bias.data_ptr(),
intermediate_size,
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return input_cont;
}
......@@ -569,14 +589,14 @@ at::Tensor ds_bias_geglu(at::Tensor& activation, at::Tensor& bias)
(const float*)bias.data_ptr(),
rows,
channels,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
} else {
launch_fused_bias_geglu((__half*)output.data_ptr(),
(const __half*)activation.data_ptr(),
(const __half*)bias.data_ptr(),
rows,
channels,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
}
return output;
......@@ -594,7 +614,7 @@ at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias)
(T*)bias.data_ptr(),
intermediate_size,
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return input_cont;
}
......@@ -610,7 +630,7 @@ at::Tensor ds_bias_add(at::Tensor& input, at::Tensor& bias)
(T*)bias.data_ptr(),
hidden_size,
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return input_cont;
}
......@@ -627,7 +647,7 @@ at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor&
// bsz,
// input_cont.size(2),
// (bias.size(0) > 1),
// Context::Instance().GetCurrentStream());
// InferenceContext::Instance().GetCurrentStream());
return input_cont;
}
......@@ -645,7 +665,7 @@ at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta,
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
} else {
launch_fused_ln((float*)output.data_ptr(),
(const float*)input.data_ptr(),
......@@ -654,7 +674,7 @@ at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta,
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
}
return output;
......@@ -675,7 +695,7 @@ void ds_layer_norm_internal(T* workspace,
epsilon,
bsz,
input.size(2),
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
}
/* Currently only used in unit testing */
......@@ -700,7 +720,7 @@ at::Tensor ds_layer_norm_residual(at::Tensor& input,
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
} else {
launch_fused_residual_ln((float*)output.data_ptr(),
(const float*)input.data_ptr(),
......@@ -711,7 +731,7 @@ at::Tensor ds_layer_norm_residual(at::Tensor& input,
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
}
return output;
......@@ -741,7 +761,7 @@ std::vector<at::Tensor> ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
} else {
launch_fused_residual_ln_store_pre_ln_res((float*)norm_output.data_ptr(),
(float*)res_output.data_ptr(),
......@@ -753,7 +773,7 @@ std::vector<at::Tensor> ds_layer_norm_residual_store_pre_ln_res(at::Tensor& inpu
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
}
return {norm_output, res_output};
......@@ -768,7 +788,7 @@ void quantized_gemm(void* output,
int bsz,
int hidden_size)
{
// T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz;
// T* weight16 = (T*)InferenceContext::Instance().GetWorkSpace() + 12 * hidden_size * bsz;
auto options = at::TensorOptions()
.dtype(at::kHalf)
......@@ -783,11 +803,11 @@ void quantized_gemm(void* output,
weight.size(0),
weight.size(1),
groups,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_T,
CUBLAS_OP_N,
weight.size(0),
......@@ -815,10 +835,11 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
at::Tensor& beta,
const float epsilon,
bool add_bias,
bool q_int8)
bool q_int8,
bool transposed_mode)
{
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
workspace += (3 * bsz * input.size(2));
ds_layer_norm_internal<T>(workspace, input, gamma, beta, epsilon);
......@@ -829,12 +850,12 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
weight.size(1),
weight.size(transposed_mode ? 0 : 1),
bsz,
input.size(2),
&alpha,
......@@ -851,9 +872,9 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1),
(transposed_mode || q_int8) ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return torch::from_blob(workspace, input.sizes(), input.options());
}
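The new transposed_mode flag switches the weight operand to CUBLAS_OP_T and reads the output width from weight.size(0) instead of weight.size(1), the same dimension already used for int8 weights, i.e. the weight is treated as stored transposed relative to the default path. A minimal sketch restating that selection; the helper names are illustrative only, not symbols in the diff:

// Hypothetical helpers mirroring the dimension/op selection used in the bindings.
static inline int output_width(const at::Tensor& weight, bool transposed_mode, bool q_int8)
{
    return (transposed_mode || q_int8) ? weight.size(0) : weight.size(1);
}
static inline cublasOperation_t weight_op(bool transposed_mode)
{
    return transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N;
}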
......@@ -870,11 +891,12 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
bool external_cache,
unsigned mp_size,
unsigned rank,
bool q_int8)
bool q_int8,
bool transposed_mode)
{
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
int out_size = q_int8 ? weight.size(0) : weight.size(1);
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
int out_size = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1);
auto options = at::TensorOptions()
.dtype(input.options().dtype())
......@@ -883,8 +905,17 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
.requires_grad(false);
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
auto inp_norm = qkv_unfused_cublas<T>(
output, input, weight, q_scale, bias, gamma, beta, epsilon, add_bias, q_int8);
auto inp_norm = qkv_unfused_cublas<T>(output,
input,
weight,
q_scale,
bias,
gamma,
beta,
epsilon,
add_bias,
q_int8,
transposed_mode);
return {output, inp_norm};
}
......@@ -912,11 +943,11 @@ void quantized_gemm(at::Tensor& output,
weight.size(1),
groups,
merge_count,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
CUBLAS_OP_T,
CUBLAS_OP_N,
weight.size(0),
......@@ -963,7 +994,7 @@ at::Tensor ds_qkv_gemm_int8(at::Tensor& input,
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return output;
}
......@@ -974,7 +1005,8 @@ at::Tensor ds_linear_layer(at::Tensor& input,
at::Tensor& bias,
bool add_bias,
bool do_flash_attn,
int num_heads)
int num_heads,
bool transposed_mode)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
......@@ -985,17 +1017,18 @@ at::Tensor ds_linear_layer(at::Tensor& input,
int head_size = input_cont.size(2) / num_heads;
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
weight.size(1),
weight.size(transposed_mode ? 0 : 1),
bsz,
input_cont.size(2),
&alpha,
......@@ -1011,9 +1044,9 @@ at::Tensor ds_linear_layer(at::Tensor& input,
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
weight.size(transposed_mode ? 0 : 1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0);
if (do_flash_attn) {
if (add_padding) {
......@@ -1026,7 +1059,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
3 * bsz * num_heads,
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
launch_bias_add_transform_0213<T>(
final_output,
......@@ -1043,7 +1076,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
-1,
false,
false,
Context::Instance().GetCurrentStream(),
InferenceContext::Instance().GetCurrentStream(),
3,
input.size(1));
return at::from_blob(final_output,
......@@ -1068,7 +1101,7 @@ at::Tensor ds_linear_layer(at::Tensor& input,
-1,
false,
false,
Context::Instance().GetCurrentStream(),
InferenceContext::Instance().GetCurrentStream(),
3,
input.size(1));
return at::from_blob(
......@@ -1086,7 +1119,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
{
int head_size = query.size(3);
int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2);
T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128;
pad_head_seq(workspace,
......@@ -1096,7 +1129,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
query.size(2),
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
pad_head_seq(key_pad_ptr,
(T*)key.data_ptr(),
query.size(0) * query.size(1),
......@@ -1104,7 +1137,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
128,
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
pad_head_seq(value_pad_ptr,
(T*)value.data_ptr(),
query.size(0) * query.size(1),
......@@ -1112,7 +1145,7 @@ std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tens
128,
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return {
at::from_blob(workspace,
{query.size(0), query.size(1), query.size(2), padded_head_size},
......@@ -1134,7 +1167,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
int key_value_length = add_padding ? 128 : key.size(1);
int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 64 : 128))
: head_size;
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1);
T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length;
launch_pad_add_transform_0213(workspace,
......@@ -1145,7 +1178,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
query.size(1),
heads,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
launch_pad_add_transform_0213(key_pad_ptr,
(T*)key.data_ptr(),
key.size(0),
......@@ -1154,7 +1187,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
key_value_length,
heads,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
launch_pad_add_transform_0213(value_pad_ptr,
(T*)value.data_ptr(),
value.size(0),
......@@ -1163,7 +1196,7 @@ std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
key_value_length,
heads,
padded_head_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return {
at::from_blob(
workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()),
......@@ -1196,7 +1229,7 @@ at::Tensor ds_linear_layer_int8(at::Tensor& input,
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return output;
}
......@@ -1205,7 +1238,8 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
at::Tensor& weight,
bool async_op,
at::Tensor& q_scale,
bool q_int8)
bool q_int8,
bool transposed_mode)
{
auto options = at::TensorOptions()
.dtype(input.options().dtype())
......@@ -1215,7 +1249,7 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
int out_size = q_int8 ? weight.size(0) : weight.size(1);
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* workspace = (T*)InferenceContext::Instance().GetWorkSpace();
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(output.data_ptr(),
......@@ -1228,12 +1262,12 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream(async_op));
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream(async_op));
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
weight.size(1),
weight.size(transposed_mode ? 0 : 1),
bsz,
input.size(2),
&alpha,
......@@ -1286,11 +1320,12 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
at::Tensor& q_scale,
at::Tensor& q_scale1,
bool q_int8,
ActivationFuncType act_func_type)
ActivationFuncType act_func_type,
bool transposed_mode)
{
int bsz = input.size(0) * input.size(1);
T* inp_norm =
(T*)Context::Instance().GetWorkSpace() + torch::numel(input) + torch::numel(output);
T* inp_norm = (T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input) +
torch::numel(output);
T* intermediate = inp_norm + torch::numel(input);
if (mlp_after_attn) {
......@@ -1303,7 +1338,7 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
epsilon,
bsz,
input.size(2),
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
} else {
ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon);
}
......@@ -1313,12 +1348,12 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
weight.size(transposed_mode ? 0 : 1),
bsz,
input.size(2),
&alpha,
......@@ -1335,15 +1370,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
if (act_func_type == ActivationFuncType::GELU) {
launch_bias_gelu(intermediate,
(T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1),
(transposed_mode || q_int8) ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
} else if (act_func_type == ActivationFuncType::ReLU) {
launch_bias_relu(intermediate,
(T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1),
(transposed_mode || q_int8) ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
}
if (q_int8) {
......@@ -1357,14 +1392,14 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
weight1.size(1),
weight1.size(transposed_mode ? 0 : 1),
bsz,
weight1.size(0),
weight1.size(transposed_mode ? 1 : 0),
&alpha,
&gemm_beta,
(T*)weight1.data_ptr(),
......@@ -1395,7 +1430,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
at::Tensor& q_scale,
at::Tensor& q_scale1,
bool q_int8,
int activation_type)
int activation_type,
bool transposed_mode)
{
auto options = at::TensorOptions()
.dtype(input.options().dtype())
......@@ -1403,10 +1439,11 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
.device(at::kCUDA)
.requires_grad(false);
int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
{input.size(0), input.size(1), out_size},
options);
int out_size = (q_int8 || transposed_mode) ? weight_out.size(0) : weight_out.size(1);
auto output =
at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() + torch::numel(input),
{input.size(0), input.size(1), out_size},
options);
int bsz = input.size(0) * input.size(1);
auto act_func_type = static_cast<ActivationFuncType>(activation_type);
......@@ -1425,7 +1462,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
q_scale,
q_scale1,
q_int8,
act_func_type);
act_func_type,
transposed_mode);
return {output, res_add};
}
......@@ -1461,7 +1499,7 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return {output, residual_add};
}
......@@ -1476,7 +1514,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
const float epsilon,
bool preLayerNorm,
bool q_int8,
bool async_op)
bool async_op,
bool transposed_mode)
{
auto options = at::TensorOptions()
.dtype(input.options().dtype())
......@@ -1484,9 +1523,10 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
.device(at::kCUDA)
.requires_grad(false);
int intm_dim = q_int8 ? weight.size(0) : weight.size(1);
int intm_dim = (transposed_mode || q_int8) ? weight.size(0) : weight.size(1);
// auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
// auto output = at::from_blob((T*)InferenceContext::Instance().GetWorkSpace() +
// torch::numel(input),
// {input.size(0), input.size(1), out_size},
// options);
// T* intermediate = (T*)input.data_ptr() + torch::numel(input);
......@@ -1505,10 +1545,10 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
bsz,
input.size(2));
} else {
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
cublasSetStream(InferenceContext::Instance().GetCublasHandle(),
InferenceContext::Instance().GetCurrentStream());
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
intm_dim,
bsz,
......@@ -1528,9 +1568,9 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
(T*)bias.data_ptr(),
intm_dim,
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
int out_size = (transposed_mode || q_int8) ? weight_out.size(0) : weight_out.size(1);
auto output = at::empty({input.size(0), input.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(output.data_ptr(),
......@@ -1541,8 +1581,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
bsz,
input.size(2));
} else {
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
cublas_gemm_ex(InferenceContext::Instance().GetCublasHandle(),
(transposed_mode ? CUBLAS_OP_T : CUBLAS_OP_N),
CUBLAS_OP_N,
out_size,
bsz,
......@@ -1558,8 +1598,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
// cudaEventRecord(Context::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true));
// cudaEventRecord(InferenceContext::Instance().GetCompEvent(2),
// InferenceContext::Instance().GetCurrentStream(true));
return output;
}
......@@ -1586,7 +1626,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state,
hidden_size,
mp_size,
preln,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
else
launch_gptj_residual_add<T>(
static_cast<T*>(residual.data_ptr()),
......@@ -1597,7 +1637,7 @@ at::Tensor& residual_add_bias(at::Tensor& hidden_state,
hidden_size,
bsz,
mp_size,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return residual;
}
......@@ -1627,8 +1667,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght());
InferenceContext::Instance().GetCurrentStream(),
InferenceContext::Instance().GetMaxTokenLenght());
else
launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
(__half*)key_cont.data_ptr(),
......@@ -1640,8 +1680,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght());
InferenceContext::Instance().GetCurrentStream(),
InferenceContext::Instance().GetMaxTokenLenght());
return {query_cont, key_cont};
}
......@@ -1670,7 +1710,7 @@ at::Tensor fused_gemm_gelu_int8(at::Tensor& input,
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
InferenceContext::Instance().GetCurrentStream());
return output;
}
......@@ -1679,7 +1719,7 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out
{
int M = moe_res.size(0) * moe_res.size(1);
int N = moe_res.size(2);
Context::Instance().SynchComm();
InferenceContext::Instance().SynchComm();
if (moe_res.scalar_type() == at::kFloat) {
launch_moe_res_matmul<float>((float*)moe_res.data_ptr(),
(float*)coef.data_ptr(),
......@@ -1698,83 +1738,77 @@ at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& out
return output;
}
void ds_release_workspace() { InferenceContext::Instance().release_workspace(); }
bool ds_retake_workspace() { return InferenceContext::Instance().retake_workspace(); }
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp16 (CUDA)");
m.def(
"softmax_context_fp32", &ds_softmax_context<float>, "DeepSpeed attention with fp32 (CUDA)");
m.def("softmax_context_fp16",
&ds_softmax_context<__half>,
"DeepSpeed attention with fp16 (CUDA)");
m.def("softmax_context_int8",
&ds_softmax_context1<__half>,
"DeepSpeed attention with int8 (CUDA)");
m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
m.def("bias_geglu", &ds_bias_geglu, "DeepSpeed Bias GEGLU (CUDA)");
m.def("bias_add_fp32", &ds_bias_add<float>, "DeepSpeed Bias Add with fp32 (CUDA)");
m.def("bias_add_fp16", &ds_bias_add<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
m.def("bias_relu_fp32", &ds_bias_relu<float>, "DeepSpeed ReLU with fp32 (CUDA)");
m.def("bias_relu_fp16", &ds_bias_relu<__half>, "DeepSpeed ReLU with fp16 (CUDA)");
m.def("bias_residual_fp32",
&ds_bias_residual<float>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
m.def("bias_residual_fp16",
&ds_bias_residual<__half>,
"DeepSpeed residual-bias add with fp16 (CUDA)");
m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm (CUDA)");
m.def(
"_layer_norm_residual", &ds_layer_norm_residual, "DeepSpeed layer norm + residual (CUDA)");
m.def("layer_norm_residual_store_pre_ln_res",
&ds_layer_norm_residual_store_pre_ln_res,
"DeepSpeed layer norm + store pre Layernorm residual (CUDA)");
m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)");
m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)");
m.def("mlp_gemm_fp32", &ds_mlp_gemm<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)");
m.def("vector_matmul_fp32", &ds_vector_matmul<float>, "DeepSpeed vector-MM with fp32 (CUDA)");
m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)");
m.def("vector_matmul_int8",
&ds_vector_matmul_int8<__half>,
"DeepSpeed vector-MM with int8 (CUDA)");
m.def("linear_layer_fp32", &ds_linear_layer<float>, "DeepSpeed linear_layer with fp32 (CUDA)");
m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)");
m.def("linear_layer_int8",
&ds_linear_layer_int8<__half>,
"DeepSpeed linear_layer with int8 (CUDA)");
m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("residual_add_bias_fp32",
&residual_add_bias<float>,
"DeepSpeed residual add with fp32 (CUDA)");
m.def("residual_add_bias_fp16",
&residual_add_bias<__half>,
"DeepSpeed residual add with fp16 (CUDA)");
m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
m.def("einsum_sec_sm_ecm_fp32",
&einsum_sec_sm_ecm<float>,
"DeepSpeed vector-MM with fp32 (CUDA)");
m.def("einsum_sec_sm_ecm_fp16",
&einsum_sec_sm_ecm<__half>,
"DeepSpeed vector-MM with fp16 (CUDA)");
m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
m.def("add_padding_fp32", &add_padding<float>, "DeepSpeed residual add with fp32 (CUDA)");
m.def("add_padding_fp16", &add_padding<__half>, "DeepSpeed residual add with fp16 (CUDA)");
m.def("pad_transform_fp32",
&padd_add_transform<float>,
"DeepSpeed residual add with fp32 (CUDA)");
m.def("pad_transform_fp16",
&padd_add_transform<__half>,
"DeepSpeed residual add with fp16 (CUDA)");
m.def("allocate_workspace_fp32",
&allocate_workspace<float>,
"DeepSpeed memory allocation for GPT inference with fp32 (CUDA)");
m.def("allocate_workspace_fp16",
&allocate_workspace<__half>,
"DeepSpeed memory allocation for GPT inference with fp16 (CUDA)");
m.def("reset_cache", &reset_cache, "Reset Cache for generation tasks");
m.def("release_workspace", &ds_release_workspace, "DeepSpeed Release Workspace");
m.def("retake_workspace", &ds_retake_workspace, "DeepSpeed Retake Workspace");
#define DEF_OPS(_name, _dtype) \
m.def("softmax_" #_name, &ds_softmax<_dtype>, "DeepSpeed SoftMax with " #_name " (CUDA)"); \
m.def("softmax_context_" #_name, \
&ds_softmax_context<_dtype>, \
"DeepSpeed attention with _name (CUDA)"); \
m.def("bias_gelu_" #_name, &ds_bias_gelu<_dtype>, "DeepSpeed Gelu with " #_name " (CUDA)"); \
m.def("bias_add_" #_name, &ds_bias_add<_dtype>, "DeepSpeed Bias Add with " #_name " (CUDA)"); \
m.def("bias_relu_" #_name, &ds_bias_relu<_dtype>, "DeepSpeed ReLU with " #_name " (CUDA)"); \
m.def("bias_residual_" #_name, \
&ds_bias_residual<_dtype>, \
"DeepSpeed residual-bias add with " #_name " (CUDA)"); \
m.def("qkv_gemm_" #_name, &ds_qkv_gemm<_dtype>, "DeepSpeed qkv gemm with " #_name " (CUDA)"); \
m.def("mlp_gemm_" #_name, &ds_mlp_gemm<_dtype>, "DeepSpeed mlp with " #_name " (CUDA)"); \
m.def("vector_matmul_" #_name, \
&ds_vector_matmul<_dtype>, \
"DeepSpeed vector-MM with " #_name " (CUDA)"); \
m.def("linear_layer_" #_name, \
&ds_linear_layer<_dtype>, \
"DeepSpeed linear_layer with " #_name " (CUDA)"); \
m.def("fused_gemm_gelu_" #_name, \
&fused_gemm_gelu<_dtype>, \
"DeepSpeed mlp with " #_name " (CUDA)"); \
m.def("residual_add_bias_" #_name, \
&residual_add_bias<_dtype>, \
"DeepSpeed residual add with " #_name " (CUDA)"); \
m.def("einsum_sec_sm_ecm_" #_name, \
&einsum_sec_sm_ecm<_dtype>, \
"DeepSpeed vector-MM with " #_name " (CUDA)"); \
m.def("add_padding_" #_name, \
&add_padding<_dtype>, \
"DeepSpeed residual add with " #_name " (CUDA)"); \
m.def("pad_transform_" #_name, \
&padd_add_transform<_dtype>, \
"DeepSpeed residual add with " #_name " (CUDA)"); \
m.def("allocate_workspace_" #_name, \
&allocate_workspace<_dtype>, \
"DeepSpeed memory allocation for GPT inference with " #_name " (CUDA)")
DEF_OPS(fp32, float);
DEF_OPS(fp16, __half);
#ifdef BF16_AVAILABLE
DEF_OPS(bf16, __nv_bfloat16);
#endif
}
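For reference, each DEF_OPS(_name, _dtype) invocation above expands into one m.def per op after preprocessing; for DEF_OPS(fp16, __half) the softmax and qkv entries become, for example:

m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp16 (CUDA)");
m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)");

so the per-dtype bindings listed individually in the old module body are now generated from a single macro, with the bf16 set guarded by BF16_AVAILABLE.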
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include "conversion_utils.h"
#include "inference_cuda_layers.h"
......@@ -60,4 +61,11 @@ void launch_bias_relu(T* input,
}
template void launch_bias_relu<float>(float*, const float*, int, int, cudaStream_t);
#ifdef BF16_AVAILABLE
template void launch_bias_relu<__nv_bfloat16>(__nv_bfloat16*,
const __nv_bfloat16*,
int,
int,
cudaStream_t);
#endif
template void launch_bias_relu<__half>(__half*, const __half*, int, int, cudaStream_t);
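These explicit instantiations are what the templated bindings link against; bf16 sits behind the same BF16_AVAILABLE guard as the DEF_OPS(bf16, __nv_bfloat16) line above. As a sketch of how a caller could pick the matching instantiation at runtime (dispatch_bias_relu is a hypothetical helper, assuming torch/extension.h and the CUDA fp16/bf16 headers are visible):

void dispatch_bias_relu(at::Tensor& input, at::Tensor& bias, int size, int bsz, cudaStream_t stream)
{
    if (input.scalar_type() == at::kHalf) {
        launch_bias_relu((__half*)input.data_ptr(), (const __half*)bias.data_ptr(), size, bsz, stream);
#ifdef BF16_AVAILABLE
    } else if (input.scalar_type() == at::kBFloat16) {
        launch_bias_relu(
            (__nv_bfloat16*)input.data_ptr(), (const __nv_bfloat16*)bias.data_ptr(), size, bsz, stream);
#endif
    } else {
        launch_bias_relu((float*)input.data_ptr(), (const float*)bias.data_ptr(), size, bsz, stream);
    }
}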