Commit f9d870f4 authored by yuguo
parents 7405fe09 80c5079c
@@ -116,7 +116,10 @@ class BlockwiseQuantizerReference:
             .reshape(M // tile_len, K // tile_len, tile_len**2)
             .amax(dim=-1)
         ).float()
-        dtype_max = torch.finfo(quant_dtype).max
+        if quant_dtype == torch.int8:
+            dtype_max = torch.iinfo(quant_dtype).max
+        else:
+            dtype_max = torch.finfo(quant_dtype).max
         scale, scale_inv, _ = scale_from_amax_tensor(
             x_dtype=x.dtype,
@@ -152,7 +155,10 @@ class BlockwiseQuantizerReference:
         eps: float,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         M, K = x.shape
-        dtype_max = torch.finfo(quant_dtype).max
+        if quant_dtype == torch.int8:
+            dtype_max = torch.iinfo(quant_dtype).max
+        else:
+            dtype_max = torch.finfo(quant_dtype).max
         x_tiled = x.reshape(M, K // tile_len, tile_len)
         amax_grid = torch.abs(x_tiled).amax(dim=-1).float()
         scale, scale_inv, _ = scale_from_amax_tensor(
@@ -272,6 +278,7 @@ class BlockwiseQuantizerReference:
         assert quant_dtype in (
             torch.float8_e4m3fn,
             torch.float8_e5m2,
+            torch.int8,
         ), "Unsupported quant dtype."
         assert quant_tile_shape in ((1, 128), (128, 128))
...
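The branch added above is needed because torch.finfo only accepts floating-point dtypes and raises TypeError for torch.int8, so integer outputs must be queried through torch.iinfo. The extrema involved (values verifiable in any recent PyTorch):

import torch

torch.iinfo(torch.int8).max            # 127
torch.finfo(torch.float8_e4m3fn).max   # 448.0
torch.finfo(torch.float8_e5m2).max     # 57344.0
# torch.finfo(torch.int8) raises TypeError, hence the iinfo branch
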
@@ -24,7 +24,10 @@ def scale_from_amax_tensor(
     - amax: Amax tensor with updates made for extrema values.
     """
     assert amax.dtype == torch.float, "amax must be a float tensor."
-    fp8_max = torch.finfo(quant_dtype).max
+    if quant_dtype == torch.int8:
+        fp8_max = torch.iinfo(quant_dtype).max
+    else:
+        fp8_max = torch.finfo(quant_dtype).max
     # Clamping amax to avoid division by small numbers
     amax = torch.max(amax, torch.tensor(eps))
...
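Combined with the clamp shown in the hunk, the helper's core arithmetic reduces to dividing the dtype maximum by the clamped amax. A minimal sketch under that assumption (the real scale_from_amax_tensor also returns an updated amax and handles pow-2 scales and zero-amax extrema, all omitted here):

import torch

def scale_from_amax_sketch(amax: torch.Tensor, quant_dtype: torch.dtype, eps: float):
    # Representable maximum of the quantized dtype (iinfo for int8, finfo for fp8).
    if quant_dtype == torch.int8:
        dtype_max = float(torch.iinfo(quant_dtype).max)
    else:
        dtype_max = float(torch.finfo(quant_dtype).max)
    amax = torch.max(amax, torch.tensor(eps))  # avoid division by small numbers
    scale = dtype_max / amax                   # maps |x| <= amax onto the full range
    return scale, 1.0 / scale                  # scale_inv, applied at dequantization
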
@@ -208,7 +208,7 @@ def check_quantization_block_tiling_versus_reference(
     ],
 )
 @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
-@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
+@pytest.mark.parametrize("quant_dtype", [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
 @pytest.mark.parametrize("eps", [0], ids=["eps_0"])
 @pytest.mark.parametrize(
     "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
@@ -243,7 +243,7 @@ def test_quantization_block_tiling_versus_reference(
     ],
 )
 @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
-@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
+@pytest.mark.parametrize("quant_dtype", [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
 @pytest.mark.parametrize("eps", [0], ids=["eps_0"])
 @pytest.mark.parametrize(
     "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
@@ -274,7 +274,7 @@ def test_quantization_block_tiling_versus_reference_fp32_scales(
     ],
 )
 @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
-@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
+@pytest.mark.parametrize("quant_dtype", [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
 @pytest.mark.parametrize("eps", [0], ids=["eps_0"])
 @pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"])
 @pytest.mark.parametrize("tile_size", [(128, 128)])
...
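Adding torch.int8 to the quant_dtype parametrization extends all three test matrices to integer outputs. As a standalone illustration of the semantics being exercised (a sketch only: per-tensor rather than per-tile scaling, and round-to-nearest is assumed, which may differ from the kernel's exact rounding mode):

import torch

x = torch.randn(128, 128)
amax = x.abs().amax().float()
scale = 127.0 / amax                                    # torch.iinfo(torch.int8).max
q = torch.clamp(torch.round(x * scale), -127, 127).to(torch.int8)
x_hat = q.float() / scale                               # dequantize via scale_inv
assert (x - x_hat).abs().max() <= 0.5 / scale + 1e-6    # at most half a quantization step
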
@@ -256,6 +256,7 @@ using int32 = int32_t;
 using int64 = int64_t;
 using fp32 = float;
 using fp16 = half;
+using int8 = int8_t;
 #ifndef __HIP_PLATFORM_AMD__
 using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
@@ -269,6 +270,7 @@ using fp8e5m2 = te_hip_fp8_e5m2;
 using fp8e8m0 = __nv_fp8_e8m0;
 #endif
 using e8m0_t = uint8_t;
+using int8 = int8_t;
 
 namespace detail {
@@ -311,6 +313,11 @@ struct TypeExtrema<fp8e4m3> {
 #endif
 };
 
+template <>
+struct TypeExtrema<int8> {
+  static constexpr float max = 127.0f;
+};
+
 template <>
 struct TypeExtrema<fp8e5m2> {
   static constexpr float max = 57344.0f;
@@ -337,7 +344,7 @@ struct TypeExtrema {
 template <typename T>
 struct TypeInfo {
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8>;
 
   template <typename U, DType current>
   struct Helper {
@@ -502,6 +509,25 @@ struct TypeInfo {
       NVTE_ERROR("Invalid type."); \
   }
 
+#define TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(dtype, type, ...) \
+  switch (dtype) { \
+    using namespace transformer_engine; \
+    case DType::kFloat8E5M2: { \
+      using type = fp8e5m2; \
+      { __VA_ARGS__ } \
+    } break; \
+    case DType::kFloat8E4M3: { \
+      using type = fp8e4m3; \
+      { __VA_ARGS__ } \
+    } break; \
+    case DType::kInt8: { \
+      using type = int8; \
+      { __VA_ARGS__ } \
+    } break; \
+    default: \
+      NVTE_ERROR("Invalid type."); \
+  }
+
 #define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
   switch (dtype) { \
     using namespace transformer_engine; \
...
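The new TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT macro mirrors the existing TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY dispatcher but adds a DType::kInt8 case; the blockwise quantize-transpose entry points later in this diff switch their output dtype through it, so the same kernels can emit int8 alongside the two fp8 formats.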
@@ -383,6 +383,7 @@ enum class DType {
   kFloat8E4M3 = 7,
   kFloat8E5M2 = 8,
   kFloat8E8M0 = 9,
+  kInt8 = 10,
   kNumTypes
 };
...
@@ -328,6 +328,7 @@ using byte = uint8_t;
 using int32 = int32_t;
 using fp32 = float;
 using fp16 = half;
+using int8 = int8_t;
 #ifndef __HIP_PLATFORM_AMD__
 using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
@@ -358,6 +359,10 @@ struct TypeToDType<fp8e4m3> {
   static constexpr DType value = DType::kFloat8E4M3;
 };
 template <>
+struct TypeToDType<int8> {
+  static constexpr DType value = DType::kInt8;
+};
+template <>
 struct TypeToDType<fp8e5m2> {
   static constexpr DType value = DType::kFloat8E5M2;
 };
...
@@ -533,7 +533,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.dtype, InputType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
+      TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
          output.dtype, OutputType,
          TRANSFORMER_ENGINE_SWITCH_CONDITION(
...
@@ -257,7 +257,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 #pragma unroll
     for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
 #ifdef __HIP_PLATFORM_AMD__
-      const float other_amax = __shfl_down(amax, delta);
+      const float other_amax = __shfl_down(amax, delta, kThreadsPerWarp);
 #else
       const float other_amax = __shfl_down_sync(mask, amax, delta);
 #endif
@@ -266,7 +266,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       amax = fmaxf(amax, other_amax);
     }
 #ifdef __HIP_PLATFORM_AMD__
-    amax = __shfl(amax, src_lane);
+    amax = __shfl(amax, src_lane, kThreadsPerWarp);
 #else
     amax = __shfl_sync(mask, amax, src_lane);
 #endif
@@ -354,7 +354,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 #pragma unroll
     for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
 #ifdef __HIP_PLATFORM_AMD__
-      const float other_amax = __shfl_down(amax, delta);
+      const float other_amax = __shfl_down(amax, delta, kThreadsPerWarp);
 #else
       const float other_amax = __shfl_down_sync(mask, amax, delta);
 #endif
@@ -363,7 +363,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       amax = fmaxf(amax, other_amax);
     }
 #ifdef __HIP_PLATFORM_AMD__
-    amax = __shfl(amax, src_lane);
+    amax = __shfl(amax, src_lane, kThreadsPerWarp);
 #else
     amax = __shfl_sync(mask, amax, src_lane);
 #endif
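Note on the HIP shuffle changes above: __shfl_down and __shfl take an explicit width parameter that defaults to warpSize, which is 64 lanes on AMD's CDNA-class GPUs. Passing kThreadsPerWarp keeps the butterfly reduction and the final broadcast confined to the same lane group that the CUDA __shfl_down_sync/__shfl_sync path operates on, rather than spanning the whole wavefront.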
@@ -479,7 +479,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.dtype, InputType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
+      TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
          output.dtype, OutputType,
          dim3 grid(num_blocks_x, num_blocks_y, 1);
...
@@ -26,7 +26,8 @@
       .value("kFloat16", transformer_engine::DType::kFloat16) \
       .value("kBFloat16", transformer_engine::DType::kBFloat16) \
       .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
-      .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \
+      .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
+      .value("kInt8", transformer_engine::DType::kInt8); \
   pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
       .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
       .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \
...
@@ -990,6 +990,7 @@ using fp8e4m3 = te_hip_fp8_e4m3;
 using fp8e5m2 = te_hip_fp8_e5m2;
 #endif
 using e8m0_t = uint8_t;
+using int8 = int8_t;
 
 constexpr uint32_t FP32_MANTISSA_BITS = 23;
 constexpr uint32_t FP32_EXPONENT_BIAS = 127;
@@ -1015,6 +1016,12 @@ struct Numeric_Traits<fp8e5m2> {
   static constexpr double maxNorm = 57344;
 };
 
+template <>
+struct Numeric_Traits<int8> {
+  static constexpr int maxUnbiasedExponent = 0;
+  static constexpr double maxNorm = 127;
+};
+
 template <typename T>
 struct Quantized_Limits {
   static constexpr int max_unbiased_exponent = Numeric_Traits<T>::maxUnbiasedExponent;
...
@@ -18,6 +18,7 @@ TE_DType = {
     torch.uint8: tex.DType.kByte,
     torch.float8_e4m3fn: tex.DType.kFloat8E4M3,
     torch.float8_e5m2: tex.DType.kFloat8E5M2,
+    torch.int8: tex.DType.kInt8,
     torch.int32: tex.DType.kInt32,
     torch.float32: tex.DType.kFloat32,
     torch.half: tex.DType.kFloat16,
...
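With the mapping extended, an int8 torch tensor's dtype now resolves to the new enum value on the Python side, e.g. TE_DType[torch.int8] yields tex.DType.kInt8.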
@@ -60,6 +60,8 @@ def check_mxfp8_support() -> Tuple[bool, str]:
 def check_fp8_block_scaling_support() -> Tuple[bool, str]:
     """Return if fp8 block scaling support is available"""
+    if IS_HIP_EXTENSION:
+        return True, ""
     if (
         get_device_compute_capability() >= (9, 0)
         and get_device_compute_capability() < (10, 0)
...
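The IS_HIP_EXTENSION early return makes fp8 block scaling report as supported on all ROCm builds, bypassing the compute-capability window (>= (9, 0) and < (10, 0), i.e. Hopper-class devices) that gates the CUDA path.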