Commit 1f9c104b authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.4'

parents 2b1428ff 8a03ff34
...@@ -263,30 +263,16 @@ void compare_scaling_factors(const std::string& name, const float* test, const f ...@@ -263,30 +263,16 @@ void compare_scaling_factors(const std::string& name, const float* test, const f
void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test, void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test,
const float* ref, const size_t rows, const float* ref, const size_t rows,
const size_t col_blocks const size_t col_blocks) {
#ifdef __HIP_PLATFORM_AMD__
, double atol = 0., double rtol = 0.
#endif
) {
const size_t test_stride = scale_align_stride(rows); const size_t test_stride = scale_align_stride(rows);
for (int i = 0; i < rows; ++i) { for (int i = 0; i < rows; ++i) {
for (int j = 0; j < col_blocks; ++j) { for (int j = 0; j < col_blocks; ++j) {
const int test_idx = i + test_stride * j; const int test_idx = i + test_stride * j;
const int ref_idx = i + rows * j; const int ref_idx = i + rows * j;
#ifdef __HIP_PLATFORM_AMD__
double t = static_cast<double>(static_cast<float>(test[test_idx]));
double r = static_cast<double>(static_cast<float>(ref[ref_idx]));
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
ASSERT_FALSE(mismatch)
<< "Error in " << name << std::endl
<< "Mismatch: " << t << " vs " << r << " at index " << test_idx
<< "," << ref_idx;
#else
ASSERT_FALSE(test[test_idx] != ref[ref_idx]) ASSERT_FALSE(test[test_idx] != ref[ref_idx])
<< "Error in " << name << std::endl << "Error in " << name << std::endl
<< "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx
<< "," << ref_idx; << "," << ref_idx;
#endif
} }
} }
} }
...@@ -425,33 +411,17 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, ...@@ -425,33 +411,17 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
float atol = 0.0; float atol = 0.0;
float rtol = 0.0; float rtol = 0.0;
#ifdef __HIP_PLATFORM_AMD__
double atol_scale = 0.0;
double rtol_scale = 0.0;
if(itype == DType::kFloat32)
{
atol_scale = 1e-5;
}
#endif
if (rowwise) { if (rowwise) {
compareResults("output_c", output_c, ref_output.get(), true, atol, rtol); compareResults("output_c", output_c, ref_output.get(), true, atol, rtol);
compare_scaling_factors_one_dimensional_blocks("scale_inv", compare_scaling_factors_one_dimensional_blocks("scale_inv",
output_c.rowwise_cpu_scale_inv_ptr<float>(), output_c.rowwise_cpu_scale_inv_ptr<float>(),
ref_scale_inv.get(), rows, blocks_x ref_scale_inv.get(), rows, blocks_x);
#ifdef __HIP_PLATFORM_AMD__
, atol_scale, rtol_scale
#endif
);
} }
if (colwise) { if (colwise) {
compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol); compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol);
compare_scaling_factors_one_dimensional_blocks("scale_inv_t", compare_scaling_factors_one_dimensional_blocks("scale_inv_t",
output_c.columnwise_cpu_scale_inv_ptr<float>(), output_c.columnwise_cpu_scale_inv_ptr<float>(),
ref_scale_inv_t.get(), cols, blocks_x_t ref_scale_inv_t.get(), cols, blocks_x_t);
#ifdef __HIP_PLATFORM_AMD__
, atol_scale, rtol_scale
#endif
);
} }
} }
......
...@@ -171,7 +171,11 @@ class BlockwiseQuantizerReference: ...@@ -171,7 +171,11 @@ class BlockwiseQuantizerReference:
qx = x_tiled * scale.reshape(M, K // tile_len, 1) qx = x_tiled * scale.reshape(M, K // tile_len, 1)
qx = torch.clamp(qx, min=-dtype_max, max=dtype_max) qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
if quant_dtype == torch.int8: if quant_dtype == torch.int8:
qx = torch.round(qx) positive_mask = qx >= 0
negative_mask = ~positive_mask
pos_part = torch.where(positive_mask, torch.floor(qx + 0.5), 0)
neg_part = torch.where(negative_mask, torch.ceil(qx - 0.5), 0)
qx = pos_part + neg_part
qx = qx.to(dtype=quant_dtype) qx = qx.to(dtype=quant_dtype)
qx = qx.reshape(M, K) qx = qx.reshape(M, K)
return qx, scale_inv return qx, scale_inv
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from typing import Tuple from typing import Tuple
import torch import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
def scale_from_amax_tensor( def scale_from_amax_tensor(
x_dtype: torch.dtype, x_dtype: torch.dtype,
...@@ -48,7 +48,11 @@ def scale_from_amax_tensor( ...@@ -48,7 +48,11 @@ def scale_from_amax_tensor(
# No subnormals and zero. # No subnormals and zero.
assert (exp > -127).all() assert (exp > -127).all()
unity = torch.tensor([1.0], device=exp.device) unity = torch.tensor([1.0], device=exp.device)
torch.ldexp(unity, exp, out=scale) if IS_HIP_EXTENSION:
host_scale = torch.ldexp(unity.cpu(), exp.cpu())
scale = host_scale.to(exp.device)
else:
torch.ldexp(unity, exp, out=scale)
# Case where amax is inf. The frexp, ldexp logic changes 0.0 scales # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
# Return 0.0 for 0.0 scale for consistency with non-pow2 scale # Return 0.0 for 0.0 scale for consistency with non-pow2 scale
# calculation. # calculation.
......
...@@ -273,7 +273,7 @@ def check_quantization_block_tiling_versus_reference( ...@@ -273,7 +273,7 @@ def check_quantization_block_tiling_versus_reference(
) )
# Check # Check
torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0) torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
# Zero out values that are don't care values # Zero out values that are don't care values
# Scale format has padding. # Scale format has padding.
scale_mask = torch.ones( scale_mask = torch.ones(
...@@ -283,7 +283,7 @@ def check_quantization_block_tiling_versus_reference( ...@@ -283,7 +283,7 @@ def check_quantization_block_tiling_versus_reference(
QuantizeResult(qx, scale_mask, None, None), tile_size QuantizeResult(qx, scale_mask, None, None), tile_size
).scale ).scale
sx = sx * scale_mask sx = sx * scale_mask
torch.testing.assert_close(sx, sx_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5) torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
if return_transpose: if return_transpose:
assert qx_t is not None assert qx_t is not None
...@@ -299,8 +299,8 @@ def check_quantization_block_tiling_versus_reference( ...@@ -299,8 +299,8 @@ def check_quantization_block_tiling_versus_reference(
QuantizeResult(qx_t, scale_mask, None, None), tile_size QuantizeResult(qx_t, scale_mask, None, None), tile_size
).scale ).scale
sx_t = sx_t * scale_mask sx_t = sx_t * scale_mask
torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0 if x_dtype != torch.float32 else 2.5e-1) torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5) torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
else: else:
# should be None # should be None
assert qx_t is None and qx_t_ref is None assert qx_t is None and qx_t_ref is None
......
...@@ -187,8 +187,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ...@@ -187,8 +187,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
// Step 3: Store cast output // Step 3: Store cast output
CType scale_data = block_tile_scale; CType scale_data = block_tile_scale;
OType scaled_elt = OType scaled_elt = 0;
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data); if constexpr(std::is_same_v<OType, int8_t>) {
scaled_elt =
static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data))));
}
else {
scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
}
tmp_output_c.data.elt[j] = scaled_elt; tmp_output_c.data.elt[j] = scaled_elt;
// Step 4: do transpose within thread tile // Step 4: do transpose within thread tile
if constexpr (kReturnTranspose) { if constexpr (kReturnTranspose) {
......
...@@ -27,90 +27,6 @@ ...@@ -27,90 +27,6 @@
#include "common/utils.cuh" #include "common/utils.cuh"
namespace transformer_engine { namespace transformer_engine {
#ifdef __HIP_PLATFORM_AMD__
__device__ bool is_little_endian()
{
int num = 1;
const char* ptr = reinterpret_cast<const char*>(&num);
if(*ptr == 1)
{
return true;
}
else
{
return false;
}
}
struct BitFloat
{
private:
char data[3];
public:
__device__ BitFloat(const float val, bool pow2scale)
{
uint32_t raw_val = *reinterpret_cast<const uint32_t*>(&val);
if (~raw_val & 0x7f800000)
{
if(pow2scale && (raw_val & 0x000000FF))
{
raw_val |= 0x100;
}
else
{
raw_val += 0x7f + ((raw_val >> 8) & 1);
}
}
else if (raw_val & 0xffff)
{
raw_val |= 0x100;
}
raw_val = (raw_val >> 8);
const char* ptr = reinterpret_cast<const char*>(&raw_val);
if(is_little_endian())
{
data[0] = ptr[0];
data[1] = ptr[1];
data[2] = ptr[2];
}
else
{
data[0] = ptr[1];
data[1] = ptr[2];
data[2] = ptr[3];
}
}
__device__ operator float() const
{
uint32_t raw_val = 0;
char* ptr = reinterpret_cast<char*>(&raw_val);
if(is_little_endian())
{
ptr[1] = data[0];
ptr[2] = data[1];
ptr[3] = data[2];
}
else
{
ptr[0] = data[0];
ptr[1] = data[1];
ptr[2] = data[2];
}
return *reinterpret_cast<const float*>(&raw_val);
}
};
struct BitFloat2 {
BitFloat u;
BitFloat v;
};
template <>
struct BytesToType<6> {
using Type = BitFloat2;
static_assert(sizeof(Type) == 6);
};
#endif
namespace { namespace {
using transformer_engine::detail::FP8BlockwiseColumnwiseOption; using transformer_engine::detail::FP8BlockwiseColumnwiseOption;
...@@ -278,12 +194,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -278,12 +194,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}; };
extern __shared__ char smem_base[]; extern __shared__ char smem_base[];
#ifdef __HIP_PLATFORM_AMD__
using HipSMemVec = Vec<std::conditional_t<std::is_same_v<IType, float>, BitFloat, IType>, kNVecSMem>;
HipSMemVec* smem = reinterpret_cast<HipSMemVec*>(&smem_base[0]);
#else
SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]); SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);
#endif
// Step 1: Load input to shared memory // Step 1: Load input to shared memory
{ {
...@@ -317,22 +228,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -317,22 +228,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
for (int i = 0; i < kNVecIn / kNVecSMem; ++i) { for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
int c = c_s + i; int c = c_s + i;
int r = r_s; int r = r_s;
#ifdef __HIP_PLATFORM_AMD__
if constexpr (std::is_same_v<IType, float>)
{
#pragma unroll
for(int j = 0; j < kNVecSMem; ++j)
{
smem[r * kSMemCol + c].data.elt[j] = BitFloat(input_vec.smem_type.data.elt[i].data.elt[j], pow_2_scaling);
}
}
else
{
smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
}
#else
smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i]; smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
#endif
} }
// Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case) // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
input_g += stride_g; input_g += stride_g;
...@@ -374,22 +270,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -374,22 +270,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
int c = c_s + i; int c = c_s + i;
int r = r_s; int r = r_s;
#ifdef __HIP_PLATFORM_AMD__
if constexpr (std::is_same_v<IType, float>)
{
#pragma unroll
for(int j = 0; j < kNVecSMem; ++j)
{
smem_vec[i].data.elt[j] = smem[r * kSMemCol + c].data.elt[j];
}
}
else
{
smem_vec[i] = smem[r * kSMemCol + c];
}
#else
smem_vec[i] = smem[r * kSMemCol + c]; smem_vec[i] = smem[r * kSMemCol + c];
#endif
} }
// Step 2.2: Compute local amax // Step 2.2: Compute local amax
CType amax = 0; CType amax = 0;
...@@ -405,7 +286,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -405,7 +286,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
#pragma unroll #pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
const float other_amax = __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp); const float other_amax =
__shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else #else
const float other_amax = __shfl_down_sync(mask, amax, delta); const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif #endif
...@@ -413,7 +295,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -413,7 +295,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
__builtin_assume(other_amax >= 0); __builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax); amax = fmaxf(amax, other_amax);
} }
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp); amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else #else
amax = __shfl_sync(mask, amax, src_lane); amax = __shfl_sync(mask, amax, src_lane);
...@@ -438,13 +320,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -438,13 +320,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll #pragma unroll
for (int j = 0; j < kNVecSMem; ++j) { for (int j = 0; j < kNVecSMem; ++j) {
if constexpr(std::is_same_v<OType, int8_t>) { if constexpr (std::is_same_v<OType, int8_t>) {
output_vec.data.elt[i * kNVecSMem + j] = output_vec.data.elt[i * kNVecSMem + j] = static_cast<OType>(lroundf(fmaxf(
static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale)))); -127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
} } else {
else {
output_vec.data.elt[i * kNVecSMem + j] = output_vec.data.elt[i * kNVecSMem + j] =
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale); static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
} }
} }
} }
...@@ -494,22 +375,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -494,22 +375,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
for (int i = 0; i < kNVecOut; ++i) { for (int i = 0; i < kNVecOut; ++i) {
int r = r_s + i; int r = r_s + i;
int c = c_s; int c = c_s;
#ifdef __HIP_PLATFORM_AMD__
if constexpr (std::is_same_v<IType, float>)
{
#pragma unroll
for(int j = 0; j < kNVecSMem; ++j)
{
smem_vec[i].data.elt[j] = smem[r * kSMemCol + c].data.elt[j];
}
}
else
{
smem_vec[i] = smem[r * kSMemCol + c];
}
#else
smem_vec[i] = smem[r * kSMemCol + c]; smem_vec[i] = smem[r * kSMemCol + c];
#endif
} }
#pragma unroll #pragma unroll
for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) { for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
...@@ -522,8 +388,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -522,8 +388,9 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
// Step 3.3: Reduce amax // Step 3.3: Reduce amax
#pragma unroll #pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
const float other_amax = __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp); const float other_amax =
__shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else #else
const float other_amax = __shfl_down_sync(mask, amax, delta); const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif #endif
...@@ -531,7 +398,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -531,7 +398,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
__builtin_assume(other_amax >= 0); __builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax); amax = fmaxf(amax, other_amax);
} }
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp); amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else #else
amax = __shfl_sync(mask, amax, src_lane); amax = __shfl_sync(mask, amax, src_lane);
...@@ -554,13 +421,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -554,13 +421,13 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
OVec output_vec; OVec output_vec;
#pragma unroll #pragma unroll
for (int i = 0; i < kNVecOut; ++i) { for (int i = 0; i < kNVecOut; ++i) {
if constexpr(std::is_same_v<OType, int8_t>) { if constexpr (std::is_same_v<OType, int8_t>) {
output_vec.data.elt[i] = output_vec.data.elt[i] = static_cast<OType>(lroundf(
static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale)))); fmaxf(-127.0f,
} fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale))));
else { } else {
output_vec.data.elt[i] = output_vec.data.elt[i] =
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale); static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
} }
} }
// Step 3.7: Store output_t // Step 3.7: Store output_t
...@@ -679,6 +546,288 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -679,6 +546,288 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
} }
} }
#ifdef __HIP_PLATFORM_AMD__
// Shared-memory geometry for the FP32 specialization below: one
// kTileDim x kTileDim tile of IType elements, addressed in kNVecSMem-wide
// vectors (smem[c * kTileDim + r], i.e. column major — see Steps 1.2 / 2.1 / 3.1).
constexpr int kFP32SMemCol = kTileDim / kNVecSMem;
constexpr int kFP32SMemSize = kSMemRow * kFP32SMemCol * kNVecSMem;

// HIP-only FP32 variant of the blockwise 1D cast+transpose kernel.
//
// Loads one kTileDim x kTileDim tile of `input` into dynamic shared memory
// (column-major layout), then, depending on rowwise_option / columnwise_option:
//   - rowwise path (Step 2): per-tile-row amax -> scale; quantized tile is
//     written to output_c and 1/scale to tile_scales_inv_c, indexed via
//     scale_stride_x / scale_stride_y;
//   - columnwise-transpose path (Step 3): per-input-column amax -> scale;
//     quantized transposed tile is written to output_t and 1/scale to
//     tile_scales_inv_t, indexed via scale_t_stride_x / scale_t_stride_y.
//
// amax reduction is done with warp shuffles across groups of kNumThreadsStore
// lanes (mask/src_lane select the group); the HIP branch passes an explicit
// 64-bit mask and kThreadsPerWarp width to the *_sync intrinsics.
// For OType == int8_t the scaled value is clamped to [-127.0f, 127.0f] and
// rounded with lroundf; otherwise it is cast directly to OType.
// kAligned=true assumes the tile lies fully inside the tensor (no bounds
// checks); kAligned=false masks loads/stores against num_rows / row_length.
// NOTE(review): the scale itself comes from
// compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
// the exact epsilon / power-of-2 semantics live in that helper — confirm there.
template <bool kAligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel_fp32(
    const IType* const input, OType* const output_c, OType* const output_t,
    CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
    const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
    const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
    FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
    const bool pow_2_scaling) {
  bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
  bool return_columnwise_transpose =
      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
  using SMemVec = Vec<IType, kNVecSMem>;
  using OVec = Vec<OType, kNVecOut>;
  // Lets one global-memory load of kNVecIn elements be viewed as
  // kNVecIn / kNVecSMem shared-memory vectors without a copy.
  union IVec {
    Vec<IType, kNVecIn> input_type;
    Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
  };
  extern __shared__ char smem_base[];
  SMemVec* smem = reinterpret_cast<SMemVec*>(smem_base);
  // Step 1: Load input to shared memory
  {
    constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad;  // stride in rows of shared memory
    constexpr int num_iterations = kTileDim / r_stride;
    const int c_s =
        (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem);  // Column in shared memory
    int r_s = threadIdx.x / kNumThreadsLoad;                      // Row in shared memory
    const size_t c_g =
        static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;     // Row in global memory
    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
    const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g)
                                            : 0;  // For not aligned case
    const IType* input_g = &input[r_g * row_length + c_g];  // Input address in global memory
#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      IVec input_vec;
      // Step 1.1: Load from global memory (input) to registers
      if constexpr (kAligned) {
        input_vec.input_type.load_from(input_g);
      } else {
        // Partial tile: load only the valid tail elements, zero-fill rows
        // past the end of the tensor.
        if (r_g < num_rows) {
          input_vec.input_type.load_from_elts(input_g, 0, num_ele);
        } else {
          input_vec.input_type.clear();
        }
      }
      // Step 1.2: Write to shared memory - Column Major
#pragma unroll
      for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
        int c = c_s + i;
        int r = r_s;
        // Column Major Store
        smem[c * kTileDim + r] = input_vec.smem_type.data.elt[i];
      }
      // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
      input_g += stride_g;
      r_s += r_stride;
      if constexpr (!kAligned) {
        r_g += r_stride;
      }
    }
  }
  // All threads must finish populating the tile before any thread reads it.
  __syncthreads();
  // Step 2: Cast and store to output_c
  if (return_rowwise) {
    constexpr int r_stride =
        kThreadsPerBlock / kNumThreadsStore;  // stride in rows of shared memory
    constexpr int num_iterations = kTileDim / r_stride;
    const int c_s =
        (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem);  // Column in shared memory
    int r_s = threadIdx.x / kNumThreadsStore;  // Row in shared memory
    const size_t c_g =
        static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Row in global memory
    const size_t stride_g = static_cast<size_t>(r_stride) * row_length;  // Stride in global memory
    const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecOut), row_length - c_g)
                                            : 0;  // For not aligned case
    OType* output_g = &output_c[r_g * row_length + c_g];  // Output address in global memory
    // Each group of kNumThreadsStore threads within a warp processes one row;
    // find the lane id of the first thread in the group to own the reduction.
    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
    // This mask represents which threads should do the reduction together.
    const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
    const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      SMemVec smem_vec[kNVecOut / kNVecSMem];
      // Step 2.1: Load from shared memory to registers - Column Major
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
        int c = c_s + i;
        int r = r_s;
        // Column Major Read
        smem_vec[i] = smem[c * kTileDim + r];
      }
      // Step 2.2: Compute local amax
      CType amax = 0;
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
        for (int j = 0; j < kNVecSMem; ++j) {
          __builtin_assume(amax >= 0);
          amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
        }
      }
      // Step 2.3: Reduce amax across the kNumThreadsStore-lane group.
#pragma unroll
      for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
        // HIP takes a 64-bit lane mask and an explicit width argument.
        const float other_amax =
            __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else
        const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif
        __builtin_assume(amax >= 0);
        __builtin_assume(other_amax >= 0);
        amax = fmaxf(amax, other_amax);
      }
      // Broadcast the reduced amax from the group leader to every lane.
#ifdef __HIP_PLATFORM_AMD__
      amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else
      amax = __shfl_sync(mask, amax, src_lane);
#endif
      CType scale;
      // Step 2.4: Compute scale
      scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
      // Step 2.5: Write scale_inv (group leader only; skip out-of-range rows
      // in the non-aligned case)
      bool write_scale_inv = is_src_lane;
      if constexpr (!kAligned) {
        write_scale_inv &= (r_g < num_rows);
      }
      if (write_scale_inv) {
        CType scale_inv = 1.0 / scale;
        size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;
        size_t col_idx = static_cast<size_t>(blockIdx.x);
        tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
      }
      // Step 2.6: Quantize
      OVec output_vec;
#pragma unroll
      for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
        for (int j = 0; j < kNVecSMem; ++j) {
          if constexpr (std::is_same_v<OType, int8_t>) {
            // int8 output: clamp to [-127, 127] and round to nearest.
            output_vec.data.elt[i * kNVecSMem + j] = static_cast<OType>(lroundf(fmaxf(
                -127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
          } else {
            output_vec.data.elt[i * kNVecSMem + j] =
                static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
          }
        }
      }
      // Step 2.7: Store output_c
      if constexpr (kAligned) {
        output_vec.store_to(output_g);
      } else {
        if (r_g < num_rows) {
          output_vec.store_to_elts(output_g, 0, num_ele);
        }
      }
      // Step 2.8: Update output address, row index of shared memory (and row index of global memory for not aligned case)
      output_g += stride_g;
      r_s += r_stride;
      if constexpr (!kAligned) {
        r_g += r_stride;
      }
    }
  }
  // Step 3: Transpose, cast and store to output_t
  if (return_columnwise_transpose) {
    constexpr int c_stride =
        kThreadsPerBlock / kNumThreadsStore;  // Stride in columns of shared memory
    constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
    const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut;  // Row in shared memory
    int c_s = threadIdx.x / kNumThreadsStore;  // Column in shared memory
    // Note the role swap: blockIdx.x now indexes rows of the transposed
    // output, blockIdx.y its columns.
    size_t r_g =
        static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Row in global memory
    const size_t c_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Column in global memory
    const size_t stride_g =
        static_cast<size_t>(c_stride) * kNVecSMem * num_rows;  // Stride in global memory
    const size_t num_ele = c_g < num_rows ? min(static_cast<size_t>(kNVecOut), num_rows - c_g)
                                          : 0;  // For not aligned case
    OType* output_g = &output_t[r_g * num_rows + c_g];  // Output address in global memory
    // Each group of kNumThreadsStore threads within a warp processes one row;
    // find the lane id of the first thread in the group to own the reduction.
    const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
    // This mask represents which threads should do the reduction together.
    const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
    const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
#pragma unroll
    for (int iter = 0; iter < num_iterations; ++iter) {
      SMemVec smem_vec[kNVecOut];
      // Step 3.1: Load from shared memory to registers - Column Major
#pragma unroll
      for (int i = 0; i < kNVecOut; ++i) {
        int r = r_s + i;
        int c = c_s;
        // Column Major Read
        smem_vec[i] = smem[c * kTileDim + r];
      }
      // One iteration per element of the shared-memory vector: each smem_idx
      // corresponds to one column of the original input tile.
#pragma unroll
      for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
        // Step 3.2: Compute local amax
        CType amax = 0;
#pragma unroll
        for (int i = 0; i < kNVecOut; ++i) {
          amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
        }
        // Step 3.3: Reduce amax across the kNumThreadsStore-lane group.
#pragma unroll
        for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
          // HIP takes a 64-bit lane mask and an explicit width argument.
          const float other_amax =
              __shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else
          const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif
          __builtin_assume(amax >= 0);
          __builtin_assume(other_amax >= 0);
          amax = fmaxf(amax, other_amax);
        }
        // Broadcast the reduced amax from the group leader to every lane.
#ifdef __HIP_PLATFORM_AMD__
        amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else
        amax = __shfl_sync(mask, amax, src_lane);
#endif
        // Step 3.4: Compute scale
        CType scale;
        scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
        // Step 3.5: Write scale_inv_t (group leader only; bounds-check the
        // transposed row against row_length in the non-aligned case)
        bool write_scale_inv = is_src_lane;
        if constexpr (!kAligned) {
          write_scale_inv &= (r_g + smem_idx < row_length);
        }
        if (write_scale_inv) {
          CType scale_inv = 1.0 / scale;
          size_t row_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + smem_idx;
          size_t col_idx = static_cast<size_t>(blockIdx.y);
          tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
        }
        // Step 3.6: Quantize
        OVec output_vec;
#pragma unroll
        for (int i = 0; i < kNVecOut; ++i) {
          if constexpr (std::is_same_v<OType, int8_t>) {
            // int8 output: clamp to [-127, 127] and round to nearest.
            output_vec.data.elt[i] = static_cast<OType>(lroundf(
                fmaxf(-127.0f,
                      fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale))));
          } else {
            output_vec.data.elt[i] =
                static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
          }
        }
        // Step 3.7: Store output_t
        if constexpr (kAligned) {
          output_vec.store_to(output_g + smem_idx * num_rows);
        } else {
          if (r_g + smem_idx < row_length) {
            output_vec.store_to_elts(output_g + smem_idx * num_rows, 0, num_ele);
          }
        }
      }
      // Step 3.8: Update output address, column index of shared memory (and row index of global memory for not aligned case)
      output_g += stride_g;
      c_s += c_stride;
      if constexpr (!kAligned) {
        r_g += c_stride * kNVecSMem;
      }
    }
  }
}
#endif
} // namespace } // namespace
} // namespace transformer_engine } // namespace transformer_engine
...@@ -767,23 +916,49 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -767,23 +916,49 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
TRANSFORMER_ENGINE_SWITCH_CONDITION( TRANSFORMER_ENGINE_SWITCH_CONDITION(
full_tile, kAligned, full_tile, kAligned,
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
using HipSMemType = std::conditional_t<std::is_same_v<InputType, float>, BitFloat, InputType>; if constexpr (std::is_same_v<InputType, float>) {
size_t smem_bytes = kSMemSize * sizeof(HipSMemType); size_t smem_bytes = kFP32SMemSize * sizeof(InputType);
#else if (smem_bytes >= 48 * 1024) {
cudaError_t err =
cudaFuncSetAttribute((const void*)&block_scaled_1d_cast_transpose_kernel_fp32<
kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
}
block_scaled_1d_cast_transpose_kernel_fp32<kAligned, float, InputType, OutputType>
<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
rowwise_option, columnwise_option, pow2_scale);
} else {
size_t smem_bytes = kSMemSize * sizeof(InputType);
if (smem_bytes >= 48 * 1024) {
cudaError_t err = cudaFuncSetAttribute(
(const void*)&block_scaled_1d_cast_transpose_kernel<kAligned, float,
InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
}
block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
rowwise_option, columnwise_option, pow2_scale);
}
#else
size_t smem_bytes = kSMemSize * sizeof(InputType); size_t smem_bytes = kSMemSize * sizeof(InputType);
#endif
// shared memory must be requested up // shared memory must be requested up
if (smem_bytes >= 48 * 1024) { if (smem_bytes >= 48 * 1024) {
#ifdef __HIP_PLATFORM_AMD__
cudaError_t err = cudaFuncSetAttribute(
(const void *)&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
#else
cudaError_t err = cudaFuncSetAttribute( cudaError_t err = cudaFuncSetAttribute(
&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>, &block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
#endif
NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size."); NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
} block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType> } block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
<<<grid, kThreadsPerBlock, smem_bytes, stream>>>( <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
...@@ -793,9 +968,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -793,9 +968,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
reinterpret_cast<float*>(scale_inv.dptr), reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
columnwise_option, pow2_scale);) // kAligned columnwise_option, pow2_scale);
) // OutputType #endif
) // InputType ) // kAligned
) // OutputType
) // InputType
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment