Unverified commit 7f2dcf91 authored by Kshitij Lakhani, committed by GitHub

[Pytorch] Decoupling framework extensions from common module (#1498)



* Remove dependency on transformer_engine::Tensor in attention.cu
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Templatize thd_partition_indices_kernel and thd_read_half_tensor_kernel ONLY to force recompilation in the including translation unit, rather than directly using the pre-compiled symbols in libtransformer.so
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
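A minimal sketch of the mechanism this step relies on (hypothetical kernel name, not the PR's code): a templated __global__ kernel has no pre-compiled symbol to link against, so every translation unit that includes its definition instantiates and compiles its own copy.

// sketch.cuh -- hedged illustration; copy_kernel is a made-up name
template <typename T>
__global__ void copy_kernel(T *dst, const T *src, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = src[i];
}
// Any .cu file that includes this header and launches, e.g.,
//   copy_kernel<float><<<grid, block>>>(dst, src, n);
// compiles a fresh instantiation into its own object code instead of pulling
// the symbol from libtransformer.so.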

* Modify attention.cu to use the templatized thd kernels. Remove the dependency on common.h
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Move the thd structs out of libtransformer.so into the framework extensions include header

Code cleanup
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Consolidate and move thd_utils from common to framework extensions
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Remove template decorators around thd_partition_indices_kernel and thd_read_half_tensor_kernel
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

Code cleanup
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
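Net effect of the last two steps: the THD kernels and helper functors end up defined in a header (thd_utils.cuh) that attention.cu includes directly, so nvcc compiles them straight into the PyTorch extension and the interim template indirection is no longer needed.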

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b4fbc2b3
@@ -65,7 +65,6 @@ list(APPEND transformer_engine_SOURCES
activation/swiglu.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
normalization/common.cpp
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../cudnn_utils.h"
#include "thd_utils.h"
namespace transformer_engine {
namespace fused_attn {
__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch,
int total_tokens, int world_size, int rank) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
int seqlen = cu_seqlens[i];
// Currently we assume that each sequence length is divisible by (world_size*2) since we have
// to distribute each sequence evenly to different GPUs.
assert(seqlen % (world_size * 2) == 0);
cu_seqlens_s[i] = seqlen / world_size;
}
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int num_threads = blockDim.x * gridDim.x;
for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
int index = token_id - cu_seqlens_s[seq_id];
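// Load-balanced split for context parallelism: each sequence is divided into
// (world_size * 2) equal chunks, and rank r owns chunk r and chunk
// (world_size * 2 - 1 - r); 'offset' below picks which of the two chunks this
// token belongs to.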
int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank;
index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset;
output[token_id] = index;
}
}
__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
int hidden_size_in_bytes, int half_idx,
int dim_size_of_token) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
}
__syncthreads();
int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
int laneid = threadIdx.x % 32;
int num_warps = (blockDim.x * gridDim.x) / 32;
int num_total_tokens = cu_seqlens_s[batch];
int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4);
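// The copy below is vectorized as float4 (16-byte) loads/stores, so
// hidden_size_in_bytes must be a multiple of sizeof(float4).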
size_t offset = static_cast<size_t>(dim_size_of_token) * hidden_size_in_bytes;
half = reinterpret_cast<void *>(reinterpret_cast<char *>(half) + offset / 2 * blockIdx.y);
tensor = reinterpret_cast<void *>(reinterpret_cast<char *>(tensor) + offset * blockIdx.y);
for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) {
int seqid = binary_search(token_id, cu_seqlens_s, batch + 1);
size_t offset_in_bytes = static_cast<size_t>(token_id) * hidden_size_in_bytes;
float4 *cur_half_token =
reinterpret_cast<float4 *>(reinterpret_cast<char *>(half) + offset_in_bytes);
offset_in_bytes =
(static_cast<size_t>(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes;
float4 *cur_token =
reinterpret_cast<float4 *>(reinterpret_cast<char *>(tensor) + offset_in_bytes);
for (int idx = laneid; idx < num_float4s_per_token; idx += 32) {
cur_half_token[idx] = cur_token[idx];
}
}
}
} // namespace fused_attn
} // namespace transformer_engine
@@ -3,12 +3,8 @@
*
* See LICENSE for license information.
************************************************************************/
#include "common/common.h"
#include "common/fused_attn/thd_utils.h"
#include "extensions.h"
using namespace transformer_engine::fused_attn;
#include "thd_utils.cuh"
constexpr int block_size = 512;
constexpr int ctas_per_sm = 4;
@@ -208,28 +204,40 @@ std::vector<py::object> fused_attn_fwd(
std::vector<py::object> output_tensors;
output_tensors.push_back(o_python);
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
// allocate memory for nvte_aux_tensor_pack.tensors
at::Tensor output_tensor;
if (nvte_aux_tensor_pack.size >= 2) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) {
if (i < nvte_aux_tensor_pack.size - 2) {
NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
output_tensor = allocateSpace(
nvte_shape_to_vector(temp_shape),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
} else if (i == nvte_aux_tensor_pack.size - 2) {
output_tensor = rng_state;
} else if (i == nvte_aux_tensor_pack.size - 1) {
output_tensor = Bias.value();
}
} else {
NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
output_tensor =
(i < nvte_aux_tensor_pack.size - 1)
? allocateSpace(
nvte_shape_to_vector(temp_shape),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false)
: rng_state;
}
} else {
NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]);
output_tensor = allocateSpace(
nvte_shape_to_vector(temp_shape),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
}
output_tensors.push_back(py::cast(output_tensor));
NVTEBasicTensor temp_data = {output_tensor.data_ptr(),
nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]),
nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])};
nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
}
// execute the kernel
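Aside: a hedged sketch of the opaque-handle pattern the loop above now follows; variable names here are illustrative, while the NVTE C-API calls are the ones used in this diff.

// NVTETensor is treated as an opaque handle: shape and dtype are queried, and
// the data pointer rebound, through the public C API instead of
// reinterpret_cast-ing to the internal transformer_engine::Tensor.
NVTETensor t = nvte_aux_tensor_pack.tensors[i];
NVTEShape shape = nvte_tensor_shape(t);  // query shape
NVTEDType dtype = nvte_tensor_type(t);   // query dtype
NVTEBasicTensor view = {dev_ptr, dtype, shape};  // dev_ptr: an assumed device buffer
nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &view);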
@@ -425,11 +433,14 @@ std::vector<py::object> fused_attn_bwd(
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec());
auto temp_vec = std::vector<size_t>(tmp.begin(), tmp.end());
const NVTEShape temp_shape = {temp_vec.data(), temp_vec.size()};
NVTEBasicTensor temp_data = {
Aux_CTX_Tensors[i].data_ptr(),
static_cast<NVTEDType>(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())),
temp_shape};
nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
}
// create dBias the same shape as Bias
@@ -662,8 +673,8 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
grid_y *= tensor.size(i);
}
dim3 grid = {grid_x, grid_y};
transformer_engine::fused_attn::thd_read_half_tensor_kernel<<<
grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
half.data_ptr(), tensor.data_ptr(), cu_seqlens.data_ptr<int>(), batch, hidden_size_in_bytes,
half_idx, tensor.size(seq_dim));
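Note the fully qualified kernel names in these launches: with the blanket "using namespace transformer_engine::fused_attn;" removed from attention.cu, each call site spells out the namespace, while the helper functors (now at global scope at the top of thd_utils.cuh) stay unqualified.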
@@ -713,13 +724,14 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
if (lse_packed) {
transformer_engine::fused_attn::thd_lse_kernel<double, true, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, lse_seqlen, second_half_lse_seqlen);
} else {
transformer_engine::fused_attn::thd_lse_kernel<double, false, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<double>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, lse_seqlen, second_half_lse_seqlen);
@@ -764,13 +776,14 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
if (lse_packed) {
transformer_engine::fused_attn::thd_lse_kernel<float, true, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, lse_seqlen, second_half_lse_seqlen);
} else {
transformer_engine::fused_attn::thd_lse_kernel<float, false, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, lse_seqlen, second_half_lse_seqlen);
@@ -829,13 +842,13 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_
dim3 grid = {grid_x, (unsigned int)num_heads};
if (lse_packed) {
transformer_engine::fused_attn::thd_out_correction_kernel<dtype, only_second_half, tile, true>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, lse_seqlen, lse_per_step_seqlen);
} else {
transformer_engine::fused_attn::thd_out_correction_kernel<dtype, only_second_half, tile, false>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
@@ -925,7 +938,8 @@ static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_p
}
dim3 grid = {grid_x, grid_y};
transformer_engine::fused_attn::thd_grad_correction_kernel<dtype, Functor_0, Functor_1,
functor_idx, 32>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
grad.data_ptr<dtype>(), grad_per_step.data_ptr<dtype>(), cu_seqlens.data_ptr<int>(),
batch, hidden_size, total_tokens);
@@ -992,8 +1006,8 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
constexpr unsigned int block = 256;
unsigned int grid = (output.size(0) + block - 1) / block;
transformer_engine::fused_attn::thd_partition_indices_kernel<<<
grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
output.data_ptr<int>(), cu_seqlens.data_ptr<int>(), batch, total_tokens, world_size, rank);
return output;
......
@@ -3,13 +3,59 @@
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_H_
#include <assert.h>
#include <cuda.h>
#include <cuda_bf16.h>
struct LseCorrectionFunctor {
__forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx,
size_t half_idx) {
double val = lse[idx];
float val_per_step = half_lse[half_idx];
double max_scale = max(val, val_per_step);
double min_scale = min(val, val_per_step);
lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale));
}
};
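// Numerically stable log-add-exp: the update above computes
//   lse_new = max(a, b) + log(1 + exp(min(a, b) - max(a, b))) = log(exp(a) + exp(b)),
// merging per-step softmax_lse values without overflowing exp().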
struct ReadLseFunctor {
__forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx,
size_t half_idx) {
half_lse[half_idx] = lse[idx];
}
};
struct EmptyFunctor {
__forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {}
};
struct CopyFunctor {
__forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {
reinterpret_cast<float4 *>(token)[idx] = reinterpret_cast<float4 *>(token_per_step)[idx];
}
};
template <typename dtype>
struct AddFunctor {
__forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) {
float4 d_ = reinterpret_cast<float4 *>(token)[idx];
dtype *p_ = reinterpret_cast<dtype *>(&d_);
float4 d = reinterpret_cast<float4 *>(token_per_step)[idx];
dtype *p = reinterpret_cast<dtype *>(&d);
#pragma unroll
for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) {
p_[i] += p[i];
}
reinterpret_cast<float4 *>(token)[idx] = d_;
}
};
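// AddFunctor keeps global-memory traffic vectorized: it loads one float4
// (16 bytes) from each tensor, reinterprets it as sizeof(float4) / sizeof(dtype)
// elements, accumulates element-wise, and stores the packed result back.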
namespace transformer_engine {
namespace fused_attn {
@@ -33,39 +79,74 @@ __forceinline__ __device__ int binary_search(int target, int *array, int len) {
/***************************************************************************************************
* Support THD format for Context Parallel: Generate partitioned indices for input tokens
**************************************************************************************************/
__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch,
int total_tokens, int world_size, int rank) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
int seqlen = cu_seqlens[i];
// Currently we assume that each sequence length is divisible by (world_size*2) since we have
// to distribute each sequence evenly to different GPUs.
assert(seqlen % (world_size * 2) == 0);
cu_seqlens_s[i] = seqlen / world_size;
}
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int num_threads = blockDim.x * gridDim.x;
for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) {
int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1);
int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id];
int index = token_id - cu_seqlens_s[seq_id];
int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank;
index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset;
output[token_id] = index;
}
}
/***************************************************************************************************
* Support THD format for Context Parallel: Read the half of a THD tensor
**************************************************************************************************/
__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
int hidden_size_in_bytes, int half_idx,
int dim_size_of_token) {
extern __shared__ int cu_seqlens_s[];
for (int i = threadIdx.x; i <= batch; i += blockDim.x) {
cu_seqlens_s[i] = cu_seqlens[i] / 2;
}
__syncthreads();
int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
int laneid = threadIdx.x % 32;
int num_warps = (blockDim.x * gridDim.x) / 32;
int num_total_tokens = cu_seqlens_s[batch];
int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4);
size_t offset = static_cast<size_t>(dim_size_of_token) * hidden_size_in_bytes;
half = reinterpret_cast<void *>(reinterpret_cast<char *>(half) + offset / 2 * blockIdx.y);
tensor = reinterpret_cast<void *>(reinterpret_cast<char *>(tensor) + offset * blockIdx.y);
for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) {
int seqid = binary_search(token_id, cu_seqlens_s, batch + 1);
size_t offset_in_bytes = static_cast<size_t>(token_id) * hidden_size_in_bytes;
float4 *cur_half_token =
reinterpret_cast<float4 *>(reinterpret_cast<char *>(half) + offset_in_bytes);
offset_in_bytes =
(static_cast<size_t>(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes;
float4 *cur_token =
reinterpret_cast<float4 *>(reinterpret_cast<char *>(tensor) + offset_in_bytes);
for (int idx = laneid; idx < num_float4s_per_token; idx += 32) {
cur_half_token[idx] = cur_token[idx];
}
}
}
/***************************************************************************************************
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
template <typename lse_dtype, bool lse_packed, typename Functor>
__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch,
@@ -163,34 +244,6 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
* Support THD format for Context Parallel: Gradients correction in backward
**************************************************************************************************/
template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx, int group_size>
__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens,
int batch, int hidden_size, int dim_size_of_token) {
@@ -246,5 +299,4 @@ __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, in
} // namespace fused_attn
} // namespace transformer_engine
#endif