Commit 1f9c104b authored by wenjh

Merge branch 'develop_v2.4'

parents 2b1428ff 8a03ff34
......@@ -263,30 +263,16 @@ void compare_scaling_factors(const std::string& name, const float* test, const f
void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test,
const float* ref, const size_t rows,
const size_t col_blocks
#ifdef __HIP_PLATFORM_AMD__
, double atol = 0., double rtol = 0.
#endif
) {
const size_t col_blocks) {
const size_t test_stride = scale_align_stride(rows);
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < col_blocks; ++j) {
const int test_idx = i + test_stride * j;
const int ref_idx = i + rows * j;
#ifdef __HIP_PLATFORM_AMD__
double t = static_cast<double>(static_cast<float>(test[test_idx]));
double r = static_cast<double>(static_cast<float>(ref[ref_idx]));
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
ASSERT_FALSE(mismatch)
<< "Error in " << name << std::endl
<< "Mismatch: " << t << " vs " << r << " at index " << test_idx
<< "," << ref_idx;
#else
ASSERT_FALSE(test[test_idx] != ref[ref_idx])
<< "Error in " << name << std::endl
<< "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx
<< "," << ref_idx;
#endif
}
}
}
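The HIP-only `atol`/`rtol` parameters are gone: with the dedicated fp32 shared-memory kernel added further down, ROCm scaling factors are expected to match the reference bit-for-bit, so the test reverts to exact comparison on all platforms.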
......@@ -425,33 +411,17 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
float atol = 0.0;
float rtol = 0.0;
#ifdef __HIP_PLATFORM_AMD__
double atol_scale = 0.0;
double rtol_scale = 0.0;
if(itype == DType::kFloat32)
{
atol_scale = 1e-5;
}
#endif
if (rowwise) {
compareResults("output_c", output_c, ref_output.get(), true, atol, rtol);
compare_scaling_factors_one_dimensional_blocks("scale_inv",
output_c.rowwise_cpu_scale_inv_ptr<float>(),
ref_scale_inv.get(), rows, blocks_x
#ifdef __HIP_PLATFORM_AMD__
, atol_scale, rtol_scale
#endif
);
ref_scale_inv.get(), rows, blocks_x);
}
if (colwise) {
compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol);
compare_scaling_factors_one_dimensional_blocks("scale_inv_t",
output_c.columnwise_cpu_scale_inv_ptr<float>(),
ref_scale_inv_t.get(), cols, blocks_x_t
#ifdef __HIP_PLATFORM_AMD__
, atol_scale, rtol_scale
#endif
);
ref_scale_inv_t.get(), cols, blocks_x_t);
}
}
......
......@@ -171,7 +171,11 @@ class BlockwiseQuantizerReference:
qx = x_tiled * scale.reshape(M, K // tile_len, 1)
qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
if quant_dtype == torch.int8:
qx = torch.round(qx)
positive_mask = qx >= 0
negative_mask = ~positive_mask
pos_part = torch.where(positive_mask, torch.floor(qx + 0.5), 0)
neg_part = torch.where(negative_mask, torch.ceil(qx - 0.5), 0)
qx = pos_part + neg_part
qx = qx.to(dtype=quant_dtype)
qx = qx.reshape(M, K)
return qx, scale_inv
......
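The reference quantizer now rounds halves away from zero instead of calling `torch.round`, which rounds halves to even (banker's rounding). C's `lroundf`, used by the kernels below, rounds halves away from zero, so the two only agree once the reference mimics it. A minimal demonstration:

```python
import torch

x = torch.tensor([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])
# torch.round: round-half-to-even
print(torch.round(x))  # tensor([-2., -2., -0., 0., 2., 2.])
# Masked floor/ceil as in the reference above: round-half-away-from-zero, like lroundf
half_away = torch.where(x >= 0, torch.floor(x + 0.5), torch.ceil(x - 0.5))
print(half_away)       # tensor([-3., -2., -1., 1., 2., 3.])
```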
......@@ -4,7 +4,7 @@
from typing import Tuple
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
def scale_from_amax_tensor(
x_dtype: torch.dtype,
......@@ -48,7 +48,11 @@ def scale_from_amax_tensor(
# No subnormals or zeros.
assert (exp > -127).all()
unity = torch.tensor([1.0], device=exp.device)
torch.ldexp(unity, exp, out=scale)
if IS_HIP_EXTENSION:
host_scale = torch.ldexp(unity.cpu(), exp.cpu())
scale = host_scale.to(exp.device)
else:
torch.ldexp(unity, exp, out=scale)
# Case where amax is inf. The frexp/ldexp logic changes 0.0 scales;
# return 0.0 for a 0.0 scale, for consistency with the non-pow2 scale
# calculation.
......
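For context, a simplified sketch of the `frexp`/`ldexp` mechanics this hunk patches (the host round-trip on HIP presumably works around a ROCm issue with `torch.ldexp`'s `out=` variant; the diff does not say). `frexp` splits `amax` into mantissa and exponent, and `ldexp(1.0, exp)` rebuilds an exact power of two, so dividing by the scale never loses mantissa bits. The real function also folds in the target dtype's maximum and the zero/inf handling visible around this hunk:

```python
import torch

amax = torch.tensor([3.7])
_, exp = torch.frexp(amax)                     # 3.7 == 0.925 * 2**2, so exp == 2
scale = torch.ldexp(torch.tensor([1.0]), exp)  # 2**exp, an exact power of two
print(scale)                                   # tensor([4.])
```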
......@@ -273,7 +273,7 @@ def check_quantization_block_tiling_versus_reference(
)
# Check
torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0)
torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
# Zero out don't-care values: the scale format has padding.
scale_mask = torch.ones(
......@@ -283,7 +283,7 @@ def check_quantization_block_tiling_versus_reference(
QuantizeResult(qx, scale_mask, None, None), tile_size
).scale
sx = sx * scale_mask
torch.testing.assert_close(sx, sx_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5)
torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
if return_transpose:
assert qx_t is not None
......@@ -299,8 +299,8 @@ def check_quantization_block_tiling_versus_reference(
QuantizeResult(qx_t, scale_mask, None, None), tile_size
).scale
sx_t = sx_t * scale_mask
torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0 if x_dtype != torch.float32 else 2.5e-1)
torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5)
torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
else:
# should be None
assert qx_t is None and qx_t_ref is None
......
......@@ -187,8 +187,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
// Step 3: Store cast output
CType scale_data = block_tile_scale;
OType scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
OType scaled_elt = 0;
if constexpr(std::is_same_v<OType, int8_t>) {
scaled_elt =
static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data))));
}
else {
scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
}
tmp_output_c.data.elt[j] = scaled_elt;
// Step 4: do transpose within thread tile
if constexpr (kReturnTranspose) {
......
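A minimal PyTorch emulation of the kernel's new `int8` path above (`quantize_int8` is a hypothetical helper mirroring the `fminf`/`fmaxf`/`lroundf` chain): scale, saturate to the symmetric range [-127, 127], then round halves away from zero:

```python
import torch

def quantize_int8(x: torch.Tensor, scale: float) -> torch.Tensor:
    y = torch.clamp(x.float() * scale, min=-127.0, max=127.0)  # fminf/fmaxf clamp
    y = torch.where(y >= 0, torch.floor(y + 0.5), torch.ceil(y - 0.5))  # lroundf
    return y.to(torch.int8)

print(quantize_int8(torch.tensor([0.4, -1.9, 300.0]), 1.0))
# tensor([  0,  -2, 127], dtype=torch.int8)
```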
......@@ -27,90 +27,6 @@
#include "common/utils.cuh"
namespace transformer_engine {
#ifdef __HIP_PLATFORM_AMD__
__device__ bool is_little_endian()
{
int num = 1;
const char* ptr = reinterpret_cast<const char*>(&num);
if(*ptr == 1)
{
return true;
}
else
{
return false;
}
}
struct BitFloat
{
private:
char data[3];
public:
__device__ BitFloat(const float val, bool pow2scale)
{
uint32_t raw_val = *reinterpret_cast<const uint32_t*>(&val);
if (~raw_val & 0x7f800000)
{
if(pow2scale && (raw_val & 0x000000FF))
{
raw_val |= 0x100;
}
else
{
raw_val += 0x7f + ((raw_val >> 8) & 1);
}
}
else if (raw_val & 0xffff)
{
raw_val |= 0x100;
}
raw_val = (raw_val >> 8);
const char* ptr = reinterpret_cast<const char*>(&raw_val);
if(is_little_endian())
{
data[0] = ptr[0];
data[1] = ptr[1];
data[2] = ptr[2];
}
else
{
data[0] = ptr[1];
data[1] = ptr[2];
data[2] = ptr[3];
}
}
__device__ operator float() const
{
uint32_t raw_val = 0;
char* ptr = reinterpret_cast<char*>(&raw_val);
if(is_little_endian())
{
ptr[1] = data[0];
ptr[2] = data[1];
ptr[3] = data[2];
}
else
{
ptr[0] = data[0];
ptr[1] = data[1];
ptr[2] = data[2];
}
return *reinterpret_cast<const float*>(&raw_val);
}
};
struct BitFloat2 {
BitFloat u;
BitFloat v;
};
template <>
struct BytesToType<6> {
using Type = BitFloat2;
static_assert(sizeof(Type) == 6);
};
#endif
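For context on the deletion: `BitFloat` packed each fp32 shared-memory element into 3 bytes by dropping the low 8 mantissa bits, rounding to nearest-even for ordinary values and forcing a sticky bit for power-of-two scaling and for NaNs so they survive truncation. The merge drops this lossy compression in favor of the dedicated fp32 kernel added further down, which keeps full 4-byte floats in shared memory and makes the bit-exact test comparisons above possible. A rough Python model of the finite-value path (the pow2-scaling sticky branch is omitted):

```python
import struct

def to_bitfloat24(val: float) -> float:
    raw = struct.unpack('<I', struct.pack('<f', val))[0]
    if (~raw) & 0x7F800000:     # finite: round the dropped 8 bits to nearest even
        raw += 0x7F + ((raw >> 8) & 1)
    elif raw & 0xFFFF:          # NaN: set a sticky bit so it stays a NaN
        raw |= 0x100
    raw &= 0xFFFFFF00           # keep only the top 24 bits (stored as 3 chars)
    return struct.unpack('<f', struct.pack('<I', raw & 0xFFFFFFFF))[0]

print(to_bitfloat24(1.0000001))  # 1.0 – the low mantissa bits are rounded away
```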
namespace {
using transformer_engine::detail::FP8BlockwiseColumnwiseOption;
......@@ -278,12 +194,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
};
extern __shared__ char smem_base[];
#ifdef __HIP_PLATFORM_AMD__
using HipSMemVec = Vec<std::conditional_t<std::is_same_v<IType, float>, BitFloat, IType>, kNVecSMem>;
HipSMemVec* smem = reinterpret_cast<HipSMemVec*>(&smem_base[0]);
#else
SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);
#endif
// Step 1: Load input to shared memory
{
......@@ -317,22 +228,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
#ifdef __HIP_PLATFORM_AMD__
if constexpr (std::is_same_v<IType, float>)
{
#pragma unroll
for(int j = 0; j < kNVecSMem; ++j)
{
smem[r * kSMemCol + c].data.elt[j] = BitFloat(input_vec.smem_type.data.elt[i].data.elt[j], pow_2_scaling);
}
}
else
{
smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
}
#else
smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
#endif
}
// Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
input_g += stride_g;
......@@ -374,22 +270,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
#ifdef __HIP_PLATFORM_AMD__
if constexpr (std::is_same_v<IType, float>)
{
#pragma unroll
for(int j = 0; j < kNVecSMem; ++j)
{
smem_vec[i].data.elt[j] = smem[r * kSMemCol + c].data.elt[j];
}
}
else
{
smem_vec[i] = smem[r * kSMemCol + c];
}
#else
smem_vec[i] = smem[r * kSMemCol + c];
#endif
}
// Step 2.2: Compute local amax
CType amax = 0;
......@@ -405,7 +286,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
#pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
const float other_amax = __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
const float other_amax =
__shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else
const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif
......@@ -413,7 +295,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else
amax = __shfl_sync(mask, amax, src_lane);
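The amax reduction used throughout these kernels, modeled in Python: each `__shfl_down_sync` round lets a lane take the max of itself and the lane `delta` positions down, halving the active width, and the closing `__shfl_sync` broadcasts the first lane's result to the whole group. A sketch for a group of 8 lanes:

```python
lanes = [0.5, 3.0, 1.25, 2.0, 0.75, 4.0, 0.1, 0.9]  # one local amax per lane
delta = len(lanes) // 2
while delta > 0:
    lanes = [max(lanes[i], lanes[i + delta]) if i + delta < len(lanes) else lanes[i]
             for i in range(len(lanes))]
    delta //= 2
print(lanes[0])  # 4.0 – the group amax, then broadcast back to every lane
```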
......@@ -438,13 +320,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
if constexpr(std::is_same_v<OType, int8_t>) {
output_vec.data.elt[i * kNVecSMem + j] =
static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
}
else {
if constexpr (std::is_same_v<OType, int8_t>) {
output_vec.data.elt[i * kNVecSMem + j] = static_cast<OType>(lroundf(fmaxf(
-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
} else {
output_vec.data.elt[i * kNVecSMem + j] =
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
}
}
}
......@@ -494,22 +375,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
for (int i = 0; i < kNVecOut; ++i) {
int r = r_s + i;
int c = c_s;
#ifdef __HIP_PLATFORM_AMD__
if constexpr (std::is_same_v<IType, float>)
{
#pragma unroll
for(int j = 0; j < kNVecSMem; ++j)
{
smem_vec[i].data.elt[j] = smem[r * kSMemCol + c].data.elt[j];
}
}
else
{
smem_vec[i] = smem[r * kSMemCol + c];
}
#else
smem_vec[i] = smem[r * kSMemCol + c];
#endif
}
#pragma unroll
for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
......@@ -522,8 +388,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
// Step 3.3: Reduce amax
#pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
const float other_amax = __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#ifdef __HIP_PLATFORM_AMD__
const float other_amax =
__shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else
const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif
......@@ -531,7 +398,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else
amax = __shfl_sync(mask, amax, src_lane);
......@@ -554,13 +421,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
if constexpr(std::is_same_v<OType, int8_t>) {
output_vec.data.elt[i] =
static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale))));
}
else {
if constexpr (std::is_same_v<OType, int8_t>) {
output_vec.data.elt[i] = static_cast<OType>(lroundf(
fmaxf(-127.0f,
fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale))));
} else {
output_vec.data.elt[i] =
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
}
}
// Step 3.7: Store output_t
......@@ -679,6 +546,288 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}
#ifdef __HIP_PLATFORM_AMD__
constexpr int kFP32SMemCol = kTileDim / kNVecSMem;
constexpr int kFP32SMemSize = kSMemRow * kFP32SMemCol * kNVecSMem;
template <bool kAligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel_fp32(
const IType* const input, OType* const output_c, OType* const output_t,
CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scaling) {
bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
bool return_columnwise_transpose =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
using SMemVec = Vec<IType, kNVecSMem>;
using OVec = Vec<OType, kNVecOut>;
union IVec {
Vec<IType, kNVecIn> input_type;
Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
};
extern __shared__ char smem_base[];
SMemVec* smem = reinterpret_cast<SMemVec*>(smem_base);
// Step 1: Load input to shared memory
{
constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory
const size_t c_g =
static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g)
: 0;  // For the unaligned case
const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
IVec input_vec;
// Step 1.1: Load from global memory (input) to registers
if constexpr (kAligned) {
input_vec.input_type.load_from(input_g);
} else {
if (r_g < num_rows) {
input_vec.input_type.load_from_elts(input_g, 0, num_ele);
} else {
input_vec.input_type.clear();
}
}
// Step 1.2: Write to shared memory - Column Major
#pragma unroll
for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
// Column Major Store
smem[c * kTileDim + r] = input_vec.smem_type.data.elt[i];
}
// Step 1.3: Update input address, row index of shared memory (and row index of global memory for the unaligned case)
input_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
__syncthreads();
// Step 2: Cast and store to output_c
if (return_rowwise) {
constexpr int r_stride =
kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory
const size_t c_g =
static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecOut), row_length - c_g)
: 0;  // For the unaligned case
OType* output_g = &output_c[r_g * row_length + c_g]; // Output address in global memory
// Each group of kNumThreadsStore threads within a warp processes one row; find the lane id
// of the group's first thread, which receives the reduced amax.
const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut / kNVecSMem];
// Step 2.1: Load from shared memory to registers - Column Major
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
// Column Major Read
smem_vec[i] = smem[c * kTileDim + r];
}
// Step 2.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
}
}
// Step 2.3: Reduce amax
#pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
const float other_amax =
__shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else
const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
#ifdef __HIP_PLATFORM_AMD__
amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else
amax = __shfl_sync(mask, amax, src_lane);
#endif
CType scale;
// Step 2.4: Compute scale
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
// Step 2.5: Write scale_inv
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g < num_rows);
}
if (write_scale_inv) {
CType scale_inv = 1.0 / scale;
size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;
size_t col_idx = static_cast<size_t>(blockIdx.x);
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
}
// Step 2.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
if constexpr (std::is_same_v<OType, int8_t>) {
output_vec.data.elt[i * kNVecSMem + j] = static_cast<OType>(lroundf(fmaxf(
-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
} else {
output_vec.data.elt[i * kNVecSMem + j] =
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
}
}
}
// Step 2.7: Store output_c
if constexpr (kAligned) {
output_vec.store_to(output_g);
} else {
if (r_g < num_rows) {
output_vec.store_to_elts(output_g, 0, num_ele);
}
}
// Step 2.8: Update output address, row index of shared memory (and row index of global memory for the unaligned case)
output_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
// Step 3: Transpose, cast and store to output_t
if (return_columnwise_transpose) {
constexpr int c_stride =
kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory
constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory
int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory
size_t r_g =
static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Row in global memory
const size_t c_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; // Column in global memory
const size_t stride_g =
static_cast<size_t>(c_stride) * kNVecSMem * num_rows; // Stride in global memory
const size_t num_ele = c_g < num_rows ? min(static_cast<size_t>(kNVecOut), num_rows - c_g)
: 0;  // For the unaligned case
OType* output_g = &output_t[r_g * num_rows + c_g]; // Output address in global memory
// Each group of kNumThreadsStore threads within a warp processes one row; find the lane id
// of the group's first thread, which receives the reduced amax.
const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut];
// Step 3.1: Load from shared memory to registers - Column Major
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
int r = r_s + i;
int c = c_s;
// Column Major Read
smem_vec[i] = smem[c * kTileDim + r];
}
#pragma unroll
for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
// Step 3.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
}
// Step 3.3: Reduce amax
#pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
const float other_amax =
__shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else
const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
#ifdef __HIP_PLATFORM_AMD__
amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else
amax = __shfl_sync(mask, amax, src_lane);
#endif
// Step 3.4: Compute scale
CType scale;
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
// Step 3.5: Write scale_inv_t
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g + smem_idx < row_length);
}
if (write_scale_inv) {
CType scale_inv = 1.0 / scale;
size_t row_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + smem_idx;
size_t col_idx = static_cast<size_t>(blockIdx.y);
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
// Step 3.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
if constexpr (std::is_same_v<OType, int8_t>) {
output_vec.data.elt[i] = static_cast<OType>(lroundf(
fmaxf(-127.0f,
fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale))));
} else {
output_vec.data.elt[i] =
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
}
}
// Step 3.7: Store output_t
if constexpr (kAligned) {
output_vec.store_to(output_g + smem_idx * num_rows);
} else {
if (r_g + smem_idx < row_length) {
output_vec.store_to_elts(output_g + smem_idx * num_rows, 0, num_ele);
}
}
}
// Step 3.8: Update output address, column index of shared memory (and row index of global memory for the unaligned case)
output_g += stride_g;
c_s += c_stride;
if constexpr (!kAligned) {
r_g += c_stride * kNVecSMem;
}
}
}
}
#endif
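One design detail of the new fp32 kernel worth noting: unlike the generic kernel's row-major `smem[r * kSMemCol + c]`, the tile here is stored column-major (`smem[c * kTileDim + r]`), so the transpose pass in Step 3 (fixed column, varying row) reads shared memory with unit stride. A tiny index model (`kTileDim = 4` is illustrative only):

```python
kTileDim = 4  # illustrative; the real tile dimension is larger

def colmajor(r: int, c: int) -> int:
    return c * kTileDim + r

print([colmajor(r, 2) for r in range(4)])  # [8, 9, 10, 11] – contiguous reads
print([colmajor(1, c) for c in range(4)])  # [1, 5, 9, 13]  – strided writes
```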
} // namespace
} // namespace transformer_engine
......@@ -767,23 +916,49 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
TRANSFORMER_ENGINE_SWITCH_CONDITION(
full_tile, kAligned,
#ifdef __HIP_PLATFORM_AMD__
using HipSMemType = std::conditional_t<std::is_same_v<InputType, float>, BitFloat, InputType>;
size_t smem_bytes = kSMemSize * sizeof(HipSMemType);
#else
#ifdef __HIP_PLATFORM_AMD__
if constexpr (std::is_same_v<InputType, float>) {
size_t smem_bytes = kFP32SMemSize * sizeof(InputType);
if (smem_bytes >= 48 * 1024) {
cudaError_t err =
cudaFuncSetAttribute((const void*)&block_scaled_1d_cast_transpose_kernel_fp32<
kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
  NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
}
block_scaled_1d_cast_transpose_kernel_fp32<kAligned, float, InputType, OutputType>
<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
rowwise_option, columnwise_option, pow2_scale);
} else {
size_t smem_bytes = kSMemSize * sizeof(InputType);
if (smem_bytes >= 48 * 1024) {
cudaError_t err = cudaFuncSetAttribute(
(const void*)&block_scaled_1d_cast_transpose_kernel<kAligned, float,
InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
  NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
}
block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
rowwise_option, columnwise_option, pow2_scale);
}
#else
size_t smem_bytes = kSMemSize * sizeof(InputType);
#endif
// Dynamic shared memory above the default 48 KB limit must be requested up front.
if (smem_bytes >= 48 * 1024) {
#ifdef __HIP_PLATFORM_AMD__
cudaError_t err = cudaFuncSetAttribute(
(const void *)&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
#else
cudaError_t err = cudaFuncSetAttribute(
&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
#endif
NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
}
block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
......@@ -793,9 +968,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
columnwise_option, pow2_scale);) // kAligned
) // OutputType
) // InputType
columnwise_option, pow2_scale);
#endif
) // kAligned
) // OutputType
) // InputType
NVTE_CHECK_CUDA(cudaGetLastError());
}
......