Commit 9a815d0b authored by wenjh

Merge branch 'develop_v2.4'

parents 3d57ff8c e2860c76
@@ -170,6 +170,8 @@ class BlockwiseQuantizerReference:
         )
         qx = x_tiled * scale.reshape(M, K // tile_len, 1)
         qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
+        if quant_dtype == torch.int8:
+            qx = torch.round(qx)
         qx = qx.to(dtype=quant_dtype)
         qx = qx.reshape(M, K)
         return qx, scale_inv
...
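Note on the new rounding step, since the diff itself does not say why it is needed: casting a float tensor to an integer dtype in PyTorch truncates toward zero, so without the explicit torch.round the reference quantizer would floor magnitudes while the kernel rounds to nearest. A minimal sketch of the difference, with illustrative values only:

import torch

x = torch.tensor([126.7, -126.7, 0.4, 0.6])

# A plain cast truncates toward zero, as in C:
# tensor([ 126, -126,    0,    0], dtype=torch.int8)
truncated = x.to(torch.int8)

# Rounding first matches a round-to-nearest kernel:
# tensor([ 127, -127,    0,    1], dtype=torch.int8)
rounded = torch.round(x).to(torch.int8)

One caveat: torch.round ties to even, while the kernel's lroundf (further down in this commit) ties away from zero, which is presumably why the updated tests below tolerate an absolute difference of 1 for int8.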
@@ -153,7 +153,7 @@ def check_quantization_block_tiling_versus_reference(
     )
     # Check
-    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
+    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0)
     # Zero out values that are don't care values
     # Scale format has padding.
     scale_mask = torch.ones(
@@ -163,7 +163,7 @@ def check_quantization_block_tiling_versus_reference(
         QuantizeResult(qx, scale_mask, None, None), tile_size
     ).scale
     sx = sx * scale_mask
-    torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
+    torch.testing.assert_close(sx, sx_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5)
     if return_transpose:
         assert qx_t is not None
@@ -179,8 +179,8 @@ def check_quantization_block_tiling_versus_reference(
             QuantizeResult(qx_t, scale_mask, None, None), tile_size
         ).scale
         sx_t = sx_t * scale_mask
-        torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
-        torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
+        torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0 if x_dtype != torch.float32 else 2.5e-1)
+        torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5)
     else:
         # should be None
         assert qx_t is None and qx_t_ref is None
@@ -344,6 +344,9 @@ def test_quantization_block_tiling_extrema_versus_reference(
     torch.testing.assert_close(sx.flatten()[0], sx_ref.flatten()[0], atol=0.0, rtol=0.0)
     if extrema_high:
-        expected_value = torch.finfo(quant_dtype).max / torch.finfo(x_dtype).max
+        if quant_dtype == torch.int8:
+            expected_value = torch.iinfo(quant_dtype).max / torch.finfo(x_dtype).max
+        else:
+            expected_value = torch.finfo(quant_dtype).max / torch.finfo(x_dtype).max
         if pow_2_scales:
             expected_value = math.floor(math.log2(expected_value))
...
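The loosened tolerances line up with the two behavioral changes in this merge. For quant_dtype == torch.int8, atol=1.0 absorbs the one-unit ties-to-even versus ties-away-from-zero disagreement noted above. For x_dtype == torch.float32 on the HIP path, the kernel now stages inputs in shared memory as 24-bit BitFloat values (see the CUDA changes below), whose round-to-nearest truncation carries a relative error of at most 2**-16, roughly 1.5e-5, inside the rtol=5e-5 used for the scales. The much larger rtol=2.5e-1 on the transposed quantized output presumably allows one-ulp differences in the FP8 result, where adjacent representable values can differ by up to 25%. A small pure-Python model of the truncation and its error bound (a sketch, not the kernel code):

import struct

def to_bitfloat24(x: float) -> float:
    # Keep the top 24 bits of an fp32, rounding to nearest (ties to even);
    # this mirrors BitFloat's default (non-pow2scale) path.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    if ~bits & 0x7F800000:            # finite: round away the low 8 bits
        bits = (bits + 0x7F + ((bits >> 8) & 1)) & 0xFFFFFFFF
    elif bits & 0xFFFF:               # NaN: keep a mantissa bit set
        bits |= 0x100
    bits &= 0xFFFFFF00
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# Dropping 8 of the 23 explicit mantissa bits leaves a relative error of
# at most 2**-16, comfortably inside rtol=5e-5.
worst = max(abs(to_bitfloat24(1 + i / 997) - (1 + i / 997)) / (1 + i / 997)
            for i in range(997))
assert worst <= 2.0 ** -16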
@@ -27,6 +27,90 @@
 #include "common/utils.cuh"
 
 namespace transformer_engine {
 
+#ifdef __HIP_PLATFORM_AMD__
+__device__ bool is_little_endian()
+{
+  int num = 1;
+  return *reinterpret_cast<const char*>(&num) == 1;
+}
+
+// 24-bit float: an fp32 value with the low 8 mantissa bits dropped, packed
+// into 3 bytes to shrink the shared-memory footprint of float inputs.
+struct BitFloat
+{
+private:
+  char data[3];
+public:
+  __device__ BitFloat(const float val, bool pow2scale)
+  {
+    uint32_t raw_val = *reinterpret_cast<const uint32_t*>(&val);
+    if (~raw_val & 0x7f800000)
+    {
+      // Finite value.
+      if(pow2scale && (raw_val & 0x000000FF))
+      {
+        // Sticky bit instead of round-to-nearest: the stored value can
+        // round down but never up, so an amax just below a power of two
+        // cannot be bumped onto it, which would double a pow-2 scale.
+        raw_val |= 0x100;
+      }
+      else
+      {
+        // Round to nearest, ties to even, before the 8-bit truncation.
+        raw_val += 0x7f + ((raw_val >> 8) & 1);
+      }
+    }
+    else if (raw_val & 0xffff)
+    {
+      // NaN: keep a surviving mantissa bit set so it does not become Inf.
+      raw_val |= 0x100;
+    }
+    raw_val = (raw_val >> 8);
+    const char* ptr = reinterpret_cast<const char*>(&raw_val);
+    if(is_little_endian())
+    {
+      data[0] = ptr[0];
+      data[1] = ptr[1];
+      data[2] = ptr[2];
+    }
+    else
+    {
+      data[0] = ptr[1];
+      data[1] = ptr[2];
+      data[2] = ptr[3];
+    }
+  }
+
+  __device__ operator float() const
+  {
+    uint32_t raw_val = 0;
+    char* ptr = reinterpret_cast<char*>(&raw_val);
+    if(is_little_endian())
+    {
+      ptr[1] = data[0];
+      ptr[2] = data[1];
+      ptr[3] = data[2];
+    }
+    else
+    {
+      ptr[0] = data[0];
+      ptr[1] = data[1];
+      ptr[2] = data[2];
+    }
+    return *reinterpret_cast<const float*>(&raw_val);
+  }
+};
+
+struct BitFloat2 {
+  BitFloat u;
+  BitFloat v;
+};
+
+// Lets the vectorized Vec load/store machinery move a pair of 3-byte
+// BitFloats as one 6-byte unit.
+template <>
+struct BytesToType<6> {
+  using Type = BitFloat2;
+  static_assert(sizeof(Type) == 6);
+};
+#endif
+
 namespace {
 
 using transformer_engine::detail::FP8BlockwiseColumnwiseOption;
@@ -169,7 +253,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
   extern __shared__ char smem_base[];
 #ifdef __HIP_PLATFORM_AMD__
-  using HipSMemVec = Vec<std::conditional_t<std::is_same_v<IType, float>, __hip_bfloat16, IType>, kNVecSMem>;
+  using HipSMemVec = Vec<std::conditional_t<std::is_same_v<IType, float>, BitFloat, IType>, kNVecSMem>;
   HipSMemVec* smem = reinterpret_cast<HipSMemVec*>(&smem_base[0]);
 #else
   SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);
@@ -213,14 +297,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 #pragma unroll
       for(int j = 0; j < kNVecSMem; ++j)
       {
-        uint32_t raw_val = *reinterpret_cast<const uint32_t*>(&input_vec.smem_type.data.elt[i].data.elt[j]);
-        // [Workaround] Under certain critical conditions, scale will be 2 * ref_scale because of float -> bfloat16.
-        // We use carry over here to avoid this issue.
-        if(pow_2_scaling && (raw_val & 0x0000FFFF))
-        {
-          raw_val |= 0x00010000;
-        }
-        smem[r * kSMemCol + c].data.elt[j] = *reinterpret_cast<const float*>(&raw_val);
+        smem[r * kSMemCol + c].data.elt[j] = BitFloat(input_vec.smem_type.data.elt[i].data.elt[j], pow_2_scaling);
       }
     }
     else
@@ -335,10 +412,16 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
 #pragma unroll
       for (int j = 0; j < kNVecSMem; ++j) {
+        if constexpr(std::is_same_v<OType, int8_t>) {
+          output_vec.data.elt[i * kNVecSMem + j] =
+              static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
+        }
+        else {
           output_vec.data.elt[i * kNVecSMem + j] =
               static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
+        }
       }
     }
     // Step 2.7: Store output_c
     if constexpr (kAligned) {
       output_vec.store_to(output_g);
@@ -445,9 +528,15 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     OVec output_vec;
 #pragma unroll
     for (int i = 0; i < kNVecOut; ++i) {
+      if constexpr(std::is_same_v<OType, int8_t>) {
+        output_vec.data.elt[i] =
+            static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale))));
+      }
+      else {
         output_vec.data.elt[i] =
             static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
+      }
     }
     // Step 3.7: Store output_t
     if constexpr (kAligned) {
       output_vec.store_to(output_g + smem_idx * num_rows);
@@ -550,7 +639,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
           full_tile, kAligned,
 #ifdef __HIP_PLATFORM_AMD__
-          using HipSMemType = std::conditional_t<std::is_same_v<InputType, float>, hip_bfloat16, InputType>;
+          using HipSMemType = std::conditional_t<std::is_same_v<InputType, float>, BitFloat, InputType>;
           size_t smem_bytes = kSMemSize * sizeof(HipSMemType);
 #else
           size_t smem_bytes = kSMemSize * sizeof(InputType);
...
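Summary of the kernel-side change: for float inputs on ROCm, shared-memory tiles were previously staged as 2-byte bfloat16 values (7 explicit mantissa bits) plus a carry-over workaround, and are now staged as 3-byte BitFloat values that keep 15 explicit mantissa bits; the BytesToType<6> specialization lets the existing Vec machinery move BitFloat pairs as single 6-byte units. The pow2scale constructor argument replaces the old 0x00010000 carry: it sets a sticky bit rather than rounding to nearest, so a value just below a power of two is never rounded up onto it, which previously could double a power-of-two scale. The new int8 output paths clamp to [-127, 127] and round with lroundf before the cast, mirroring the reference's clamp-then-round. A pure-Python sketch of the sticky-bit behavior (illustrative, not the kernel code):

import struct

def bitfloat24_pow2(x: float) -> float:
    # Model of BitFloat(val, pow2scale=true): if any of the low 8 bits are
    # set, set bit 8 as a sticky bit instead of rounding to nearest, then
    # truncate. The stored value may round down but never up.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    if (~bits & 0x7F800000) and (bits & 0xFF):
        bits |= 0x100
    bits &= 0xFFFFFF00
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# 0x3FFFFFFF is just below 2.0: round-to-nearest truncation would produce
# exactly 2.0 and double a pow-2 scale; the sticky encoding stays below.
x = struct.unpack("<f", struct.pack("<I", 0x3FFFFFFF))[0]
assert bitfloat24_pow2(x) < 2.0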