Unverified Commit 60ce21f4 authored by Michael Goldfarb's avatar Michael Goldfarb Committed by GitHub
Browse files

[Common] Moved framework agnostic THD kernels to common. (#1339)



Moved framework agnostic THD kernels to common.

---------
Signed-off-by: default avatarMichael Goldfarb <mgoldfarb@nvidia.com>
parent ae393e81
...@@ -70,6 +70,7 @@ jobs: ...@@ -70,6 +70,7 @@ jobs:
run: pip install . -v run: pip install . -v
env: env:
NVTE_FRAMEWORK: jax NVTE_FRAMEWORK: jax
MAX_JOBS: 1
- name: 'Sanity check' - name: 'Sanity check'
run: python tests/jax/test_sanity_import.py run: python tests/jax/test_sanity_import.py
paddle: paddle:
......
...@@ -61,6 +61,7 @@ list(APPEND transformer_engine_SOURCES ...@@ -61,6 +61,7 @@ list(APPEND transformer_engine_SOURCES
activation/swiglu.cu activation/swiglu.cu
fused_attn/fused_attn_fp8.cu fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp fused_attn/fused_attn.cpp
fused_attn/thd_utils.cu
fused_attn/utils.cu fused_attn/utils.cu
gemm/cublaslt_gemm.cu gemm/cublaslt_gemm.cu
layer_norm/ln_api.cpp layer_norm/ln_api.cpp
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../cudnn_utils.h"
#include "thd_utils.h"
namespace transformer_engine {
namespace fused_attn {
__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch,
int total_tokens, int world_size, int rank) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
int seqlen = cu_seqlens[i];
// Currently we assume that each sequence length is divisible by (world_size*2) since we have
// to distribute each sequence evenly to different GPUs.
assert(seqlen % (world_size * 2) == 0);
cu_seqlens_s[i] = seqlen / world_size;
}
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int num_threads = blockDim.x * gridDim.x;
for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
int index = token_id - cu_seqlens_s[seq_id];
int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank;
index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset;
output[token_id] = index;
}
}
__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
int hidden_size_in_bytes, int half_idx,
int dim_size_of_token) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
}
__syncthreads();
int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
int laneid = threadIdx.x % 32;
int num_warps = (blockDim.x * gridDim.x) / 32;
int num_total_tokens = cu_seqlens_s[batch];
int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4);
size_t offset = static_cast<size_t>(dim_size_of_token) * hidden_size_in_bytes;
half = reinterpret_cast<void *>(reinterpret_cast<char *>(half) + offset / 2 * blockIdx.y);
tensor = reinterpret_cast<void *>(reinterpret_cast<char *>(tensor) + offset * blockIdx.y);
for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) {
int seqid = binary_search(token_id, cu_seqlens_s, batch + 1);
size_t offset_in_bytes = static_cast<size_t>(token_id) * hidden_size_in_bytes;
float4 *cur_half_token =
reinterpret_cast<float4 *>(reinterpret_cast<char *>(half) + offset_in_bytes);
offset_in_bytes =
(static_cast<size_t>(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes;
float4 *cur_token =
reinterpret_cast<float4 *>(reinterpret_cast<char *>(tensor) + offset_in_bytes);
for (int idx = laneid; idx < num_float4s_per_token; idx += 32) {
cur_half_token[idx] = cur_token[idx];
}
}
}
} // namespace fused_attn
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_
#include <cuda.h>
#include <cuda_bf16.h>
namespace transformer_engine {
namespace fused_attn {
/***************************************************************************************************
* Support THD format for Context Parallel: Binary search an array for a target value
**************************************************************************************************/
__forceinline__ __device__ int binary_search(int target, int *array, int len) {
int left = 1, right = len - 1;
while (left < right) {
int mid = (left + right) / 2;
if (array[mid] <= target) {
left = mid + 1;
} else {
right = mid;
}
}
return left - 1;
}
/***************************************************************************************************
* Support THD format for Context Parallel: Generate partitioned indices for input tokens
**************************************************************************************************/
__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch,
int total_tokens, int world_size, int rank);
/***************************************************************************************************
* Support THD format for Context Parallel: Read the half of a THD tensor
**************************************************************************************************/
__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
int hidden_size_in_bytes, int half_idx,
int dim_size_of_token);
/***************************************************************************************************
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
struct LseCorrectionFunctor {
__forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx,
size_t half_idx) {
double val = lse[idx];
float val_per_step = half_lse[half_idx];
double max_scale = max(val, val_per_step);
double min_scale = min(val, val_per_step);
lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale));
}
};
struct ReadLseFunctor {
__forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx,
size_t half_idx) {
half_lse[half_idx] = lse[idx];
}
};
template <typename lse_dtype, bool lse_packed, typename Functor>
__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
int num_heads, int total_tokens) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
}
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int num_threads = blockDim.x * gridDim.x;
int num_total_tokens = cu_seqlens_s[batch];
for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, half_idx;
if constexpr (lse_packed) {
idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1];
half_idx = head_id * total_tokens / 2 + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
idx = row * total_tokens + col + seq_len;
half_idx = row * total_tokens / 2 + col;
}
Functor::run(lse, half_lse, idx, half_idx);
}
}
}
/***************************************************************************************************
* Support THD format for Context Parallel: Out correction in forward
**************************************************************************************************/
template <typename dtype, int only_second_half, int tile_size, bool lse_packed>
__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
float *lse_per_step, int *cu_seqlens, int batch,
int num_heads, int dim_per_head, int lse_seqlen) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1);
}
__syncthreads();
int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size;
int lane_id = threadIdx.x % tile_size;
int num_tiles = (blockDim.x * gridDim.x) / tile_size;
int num_total_tokens = cu_seqlens_s[batch];
int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4);
for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, idx_per_step;
if constexpr (lse_packed) {
idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
idx = row * lse_seqlen + col + seq_len * only_second_half;
idx_per_step = row * lse_seqlen / (only_second_half + 1) + col;
}
float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);
idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
idx = (idx * num_heads + head_id) * dim_per_head;
idx_per_step = (static_cast<size_t>(token_id) * num_heads + head_id) * dim_per_head;
dtype *cur_out = out + idx;
dtype *cur_out_per_step = out_per_step + idx_per_step;
for (int j = lane_id; j < num_loops_per_head; j += tile_size) {
float4 data_per_step = reinterpret_cast<float4 *>(cur_out_per_step)[j];
float4 data = reinterpret_cast<float4 *>(cur_out)[j];
dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step);
dtype *p = reinterpret_cast<dtype *>(&data);
for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp);
}
reinterpret_cast<float4 *>(cur_out)[j] = data;
}
}
}
}
/***************************************************************************************************
* Support THD format for Context Parallel: Gradients correction in backward
**************************************************************************************************/
struct EmptyFunctor {
__forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {}
};
struct CopyFunctor {
__forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {
reinterpret_cast<float4 *>(token)[idx] = reinterpret_cast<float4 *>(token_per_step)[idx];
}
};
template <typename dtype>
struct AddFunctor {
__forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) {
float4 d_ = reinterpret_cast<float4 *>(token)[idx];
dtype *p_ = reinterpret_cast<dtype *>(&d_);
float4 d = reinterpret_cast<float4 *>(token_per_step)[idx];
dtype *p = reinterpret_cast<dtype *>(&d);
#pragma unroll
for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) {
p_[i] += p[i];
}
reinterpret_cast<float4 *>(token)[idx] = d_;
}
};
template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx, int group_size>
__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens,
int batch, int hidden_size, int dim_size_of_token) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
if constexpr (functor_idx < 2) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
} else {
cu_seqlens_s[i] = cu_seqlens[i];
}
}
__syncthreads();
int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size;
int lane_id = threadIdx.x % group_size;
int num_groups = (blockDim.x * gridDim.x) / group_size;
int num_total_tokens = cu_seqlens_s[batch];
int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4);
size_t offset = static_cast<size_t>(dim_size_of_token) * hidden_size;
if constexpr (functor_idx < 2) {
grad_per_step = grad_per_step + offset / 2 * blockIdx.y;
} else {
grad_per_step = grad_per_step + offset * blockIdx.y;
}
grad = grad + offset * blockIdx.y;
for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
int token_offset;
bool is_first_half;
if constexpr (functor_idx < 2) {
token_offset = cu_seqlens_s[seq_id + functor_idx];
is_first_half = (functor_idx == 0);
} else {
token_offset = 0;
int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2);
}
dtype *token = &grad[(token_id + token_offset) * static_cast<size_t>(hidden_size)];
dtype *token_per_step = &grad_per_step[token_id * static_cast<size_t>(hidden_size)];
for (int idx = lane_id; idx < num_inner_loops; idx += group_size) {
if (is_first_half) {
Functor_0::run(token, token_per_step, idx);
} else {
Functor_1::run(token, token_per_step, idx);
}
}
}
}
} // namespace fused_attn
} // namespace transformer_engine
#endif
...@@ -4,8 +4,11 @@ ...@@ -4,8 +4,11 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "common/fused_attn/thd_utils.h"
#include "extensions.h" #include "extensions.h"
using namespace transformer_engine::fused_attn;
constexpr int block_size = 512; constexpr int block_size = 512;
constexpr int ctas_per_sm = 4; constexpr int ctas_per_sm = 4;
...@@ -1359,64 +1362,10 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { ...@@ -1359,64 +1362,10 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
return qkv; return qkv;
} }
/***************************************************************************************************
* Support THD format for Context Parallel: Binary search
**************************************************************************************************/
__forceinline__ __device__ int binary_search(int target, int *array, int len) {
int left = 1, right = len - 1;
while (left < right) {
int mid = (left + right) / 2;
if (array[mid] <= target) {
left = mid + 1;
} else {
right = mid;
}
}
return left - 1;
}
/*************************************************************************************************** /***************************************************************************************************
* Support THD format for Context Parallel: Read the half of a THD tensor * Support THD format for Context Parallel: Read the half of a THD tensor
**************************************************************************************************/ **************************************************************************************************/
__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
int hidden_size_in_bytes, int half_idx,
int dim_size_of_token) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
}
__syncthreads();
int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
int laneid = threadIdx.x % 32;
int num_warps = (blockDim.x * gridDim.x) / 32;
int num_total_tokens = cu_seqlens_s[batch];
int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4);
size_t offset = static_cast<size_t>(dim_size_of_token) * hidden_size_in_bytes;
half = reinterpret_cast<void *>(reinterpret_cast<char *>(half) + offset / 2 * blockIdx.y);
tensor = reinterpret_cast<void *>(reinterpret_cast<char *>(tensor) + offset * blockIdx.y);
for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) {
int seqid = binary_search(token_id, cu_seqlens_s, batch + 1);
size_t offset_in_bytes = static_cast<size_t>(token_id) * hidden_size_in_bytes;
float4 *cur_half_token =
reinterpret_cast<float4 *>(reinterpret_cast<char *>(half) + offset_in_bytes);
offset_in_bytes =
(static_cast<size_t>(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes;
float4 *cur_token =
reinterpret_cast<float4 *>(reinterpret_cast<char *>(tensor) + offset_in_bytes);
for (int idx = laneid; idx < num_float4s_per_token; idx += 32) {
cur_half_token[idx] = cur_token[idx];
}
}
}
at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens, at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens,
int half_idx) { int half_idx) {
NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4); NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4);
...@@ -1464,51 +1413,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s ...@@ -1464,51 +1413,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
* Support THD format for Context Parallel: softmax_lse related operations * Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/ **************************************************************************************************/
template <typename lse_dtype, bool lse_packed, typename Functor>
__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
int num_heads, int total_tokens) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
}
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int num_threads = blockDim.x * gridDim.x;
int num_total_tokens = cu_seqlens_s[batch];
for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, half_idx;
if constexpr (lse_packed) {
idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1];
half_idx = head_id * total_tokens / 2 + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
idx = row * total_tokens + col + seq_len;
half_idx = row * total_tokens / 2 + col;
}
Functor::run(lse, half_lse, idx, half_idx);
}
}
}
struct LseCorrectionFunctor {
__forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx,
size_t half_idx) {
double val = lse[idx];
float val_per_step = half_lse[half_idx];
double max_scale = max(val, val_per_step);
double min_scale = min(val, val_per_step);
lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale));
}
};
void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens, bool lse_packed) { const at::Tensor &cu_seqlens, bool lse_packed) {
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double); NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double);
...@@ -1559,13 +1463,6 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st ...@@ -1559,13 +1463,6 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st
} }
} }
struct ReadLseFunctor {
__forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx,
size_t half_idx) {
half_lse[half_idx] = lse[idx];
}
};
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
bool lse_packed) { bool lse_packed) {
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
...@@ -1620,59 +1517,6 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_ ...@@ -1620,59 +1517,6 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_
* Support THD format for Context Parallel: Out correction in forward * Support THD format for Context Parallel: Out correction in forward
**************************************************************************************************/ **************************************************************************************************/
template <typename dtype, int only_second_half, int tile_size, bool lse_packed>
__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
float *lse_per_step, int *cu_seqlens, int batch,
int num_heads, int dim_per_head, int lse_seqlen) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1);
}
__syncthreads();
int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size;
int lane_id = threadIdx.x % tile_size;
int num_tiles = (blockDim.x * gridDim.x) / tile_size;
int num_total_tokens = cu_seqlens_s[batch];
int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4);
for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) {
size_t idx, idx_per_step;
if constexpr (lse_packed) {
idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id;
} else {
size_t row = static_cast<size_t>(seq_id) * num_heads + head_id;
int col = token_id - cu_seqlens_s[seq_id];
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
idx = row * lse_seqlen + col + seq_len * only_second_half;
idx_per_step = row * lse_seqlen / (only_second_half + 1) + col;
}
float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]);
idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half;
idx = (idx * num_heads + head_id) * dim_per_head;
idx_per_step = (static_cast<size_t>(token_id) * num_heads + head_id) * dim_per_head;
dtype *cur_out = out + idx;
dtype *cur_out_per_step = out_per_step + idx_per_step;
for (int j = lane_id; j < num_loops_per_head; j += tile_size) {
float4 data_per_step = reinterpret_cast<float4 *>(cur_out_per_step)[j];
float4 data = reinterpret_cast<float4 *>(cur_out)[j];
dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step);
dtype *p = reinterpret_cast<dtype *>(&data);
for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp);
}
reinterpret_cast<float4 *>(cur_out)[j] = data;
}
}
}
}
template <typename dtype, int only_second_half> template <typename dtype, int only_second_half>
static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step, static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step,
const at::Tensor &lse, const at::Tensor &lse_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step,
...@@ -1773,87 +1617,6 @@ void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at ...@@ -1773,87 +1617,6 @@ void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at
* Support THD format for Context Parallel: Gradients correction in backward * Support THD format for Context Parallel: Gradients correction in backward
**************************************************************************************************/ **************************************************************************************************/
template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx, int group_size>
__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens,
int batch, int hidden_size, int dim_size_of_token) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
if constexpr (functor_idx < 2) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
} else {
cu_seqlens_s[i] = cu_seqlens[i];
}
}
__syncthreads();
int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size;
int lane_id = threadIdx.x % group_size;
int num_groups = (blockDim.x * gridDim.x) / group_size;
int num_total_tokens = cu_seqlens_s[batch];
int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4);
size_t offset = static_cast<size_t>(dim_size_of_token) * hidden_size;
if constexpr (functor_idx < 2) {
grad_per_step = grad_per_step + offset / 2 * blockIdx.y;
} else {
grad_per_step = grad_per_step + offset * blockIdx.y;
}
grad = grad + offset * blockIdx.y;
for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
int token_offset;
bool is_first_half;
if constexpr (functor_idx < 2) {
token_offset = cu_seqlens_s[seq_id + functor_idx];
is_first_half = (functor_idx == 0);
} else {
token_offset = 0;
int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2);
}
dtype *token = &grad[(token_id + token_offset) * static_cast<size_t>(hidden_size)];
dtype *token_per_step = &grad_per_step[token_id * static_cast<size_t>(hidden_size)];
for (int idx = lane_id; idx < num_inner_loops; idx += group_size) {
if (is_first_half) {
Functor_0::run(token, token_per_step, idx);
} else {
Functor_1::run(token, token_per_step, idx);
}
}
}
}
struct EmptyFunctor {
__forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {}
};
struct CopyFunctor {
__forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {
reinterpret_cast<float4 *>(token)[idx] = reinterpret_cast<float4 *>(token_per_step)[idx];
}
};
template <typename dtype>
struct AddFunctor {
__forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) {
float4 d_ = reinterpret_cast<float4 *>(token)[idx];
dtype *p_ = reinterpret_cast<dtype *>(&d_);
float4 d = reinterpret_cast<float4 *>(token_per_step)[idx];
dtype *p = reinterpret_cast<dtype *>(&d);
#pragma unroll
for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) {
p_[i] += p[i];
}
reinterpret_cast<float4 *>(token)[idx] = d_;
}
};
template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx> template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx>
static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_per_step, static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens) { const at::Tensor &cu_seqlens) {
...@@ -1945,31 +1708,6 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, ...@@ -1945,31 +1708,6 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
* Support THD format for Context Parallel: Generate partitioned indices for input tokens * Support THD format for Context Parallel: Generate partitioned indices for input tokens
**************************************************************************************************/ **************************************************************************************************/
__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch,
int total_tokens, int world_size, int rank) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
int seqlen = cu_seqlens[i];
// Currently we assume that each sequence length is divisible by (world_size*2) since we have
// to distribute each sequence evenly to different GPUs.
assert(seqlen % (world_size * 2) == 0);
cu_seqlens_s[i] = seqlen / world_size;
}
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int num_threads = blockDim.x * gridDim.x;
for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
int index = token_id - cu_seqlens_s[seq_id];
int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank;
index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset;
output[token_id] = index;
}
}
at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
int world_size, int rank) { int world_size, int rank) {
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment