Commit f9d870f4 authored by yuguo
parents 7405fe09 80c5079c
......
@@ -116,6 +116,9 @@ class BlockwiseQuantizerReference:
             .reshape(M // tile_len, K // tile_len, tile_len**2)
             .amax(dim=-1)
         ).float()
-        dtype_max = torch.finfo(quant_dtype).max
+        if quant_dtype == torch.int8:
+            dtype_max = torch.iinfo(quant_dtype).max
+        else:
+            dtype_max = torch.finfo(quant_dtype).max
         scale, scale_inv, _ = scale_from_amax_tensor(
......
@@ -152,6 +155,9 @@ class BlockwiseQuantizerReference:
         eps: float,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         M, K = x.shape
-        dtype_max = torch.finfo(quant_dtype).max
+        if quant_dtype == torch.int8:
+            dtype_max = torch.iinfo(quant_dtype).max
+        else:
+            dtype_max = torch.finfo(quant_dtype).max
         x_tiled = x.reshape(M, K // tile_len, tile_len)
         amax_grid = torch.abs(x_tiled).amax(dim=-1).float()
......
@@ -272,6 +278,7 @@ class BlockwiseQuantizerReference:
         assert quant_dtype in (
             torch.float8_e4m3fn,
             torch.float8_e5m2,
+            torch.int8,
         ), "Unsupported quant dtype."
         assert quant_tile_shape in ((1, 128), (128, 128))
......
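Note: `torch.finfo` raises a `TypeError` for integer dtypes, so the int8 path has to query `torch.iinfo` instead. A minimal standalone sketch of the dtype-max selection added above:

```python
import torch

def dtype_max(quant_dtype: torch.dtype) -> float:
    """Largest representable magnitude of the quantization dtype."""
    if quant_dtype == torch.int8:
        return float(torch.iinfo(quant_dtype).max)  # 127
    return torch.finfo(quant_dtype).max  # 448.0 (e4m3fn) or 57344.0 (e5m2)

assert dtype_max(torch.int8) == 127.0
assert dtype_max(torch.float8_e4m3fn) == 448.0
assert dtype_max(torch.float8_e5m2) == 57344.0
```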
......
@@ -24,6 +24,9 @@ def scale_from_amax_tensor(
     - amax: Amax tensor with updates made for extrema values.
     """
     assert amax.dtype == torch.float, "amax must be a float tensor."
-    fp8_max = torch.finfo(quant_dtype).max
+    if quant_dtype == torch.int8:
+        fp8_max = torch.iinfo(quant_dtype).max
+    else:
+        fp8_max = torch.finfo(quant_dtype).max
     # Clamping amax to avoid division by small numbers
     amax = torch.max(amax, torch.tensor(eps))
......
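For context, a simplified model of what `scale_from_amax_tensor` derives from the clamped amax (an assumption-laden sketch: it omits the pow-2-scale and extrema handling in the real helper):

```python
import torch

def scale_from_amax(amax: torch.Tensor, dtype_max: float, eps: float):
    # Clamp amax so near-zero blocks don't produce infinite scales.
    amax = torch.max(amax, torch.tensor(eps))
    scale = dtype_max / amax             # maps |x| <= amax into the quantized range
    scale_inv = torch.reciprocal(scale)  # kept alongside for dequantization
    return scale, scale_inv, amax
```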
......
@@ -208,7 +208,7 @@ def check_quantization_block_tiling_versus_reference(
     ],
 )
 @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
-@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
+@pytest.mark.parametrize("quant_dtype", [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
 @pytest.mark.parametrize("eps", [0], ids=["eps_0"])
 @pytest.mark.parametrize(
     "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
......
@@ -243,7 +243,7 @@ def test_quantization_block_tiling_versus_reference(
     ],
 )
 @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
-@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
+@pytest.mark.parametrize("quant_dtype", [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
 @pytest.mark.parametrize("eps", [0], ids=["eps_0"])
 @pytest.mark.parametrize(
     "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
......
@@ -274,7 +274,7 @@ def test_quantization_block_tiling_versus_reference_fp32_scales(
     ],
 )
 @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
-@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
+@pytest.mark.parametrize("quant_dtype", [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
 @pytest.mark.parametrize("eps", [0], ids=["eps_0"])
 @pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"])
 @pytest.mark.parametrize("tile_size", [(128, 128)])
......
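The new `torch.int8` parametrization exercises symmetric integer quantization; per 128x128 tile the reference does roughly the following (a hedged sketch, not the actual reference code -- rounding details may differ):

```python
import torch

x = torch.randn(128, 128, dtype=torch.float32)    # one quantization tile
amax = x.abs().amax().float()
scale = 127.0 / amax                              # dtype_max for int8
q = torch.clamp(torch.round(x * scale), -127, 127).to(torch.int8)
x_dq = q.float() / scale                          # dequantize with scale_inv
assert (x - x_dq).abs().max() <= 0.5 / scale + 1e-6  # error within half a step
```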
......
@@ -256,6 +256,7 @@ using int32 = int32_t;
 using int64 = int64_t;
 using fp32 = float;
 using fp16 = half;
+using int8 = int8_t;
 #ifndef __HIP_PLATFORM_AMD__
 using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
......
@@ -269,6 +270,7 @@ using fp8e5m2 = te_hip_fp8_e5m2;
 using fp8e8m0 = __nv_fp8_e8m0;
 #endif
 using e8m0_t = uint8_t;
+using int8 = int8_t;
 namespace detail {
......
@@ -311,6 +313,11 @@ struct TypeExtrema<fp8e4m3> {
 #endif
 };
+template <>
+struct TypeExtrema<int8> {
+  static constexpr float max = 127.0f;
+};
+
 template <>
 struct TypeExtrema<fp8e5m2> {
   static constexpr float max = 57344.0f;
......
@@ -337,7 +344,7 @@ struct TypeExtrema {
 template <typename T>
 struct TypeInfo {
-  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, int8>;
   template <typename U, DType current>
   struct Helper {
......
@@ -502,6 +509,25 @@ struct TypeInfo {
       NVTE_ERROR("Invalid type."); \
   }

+#define TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(dtype, type, ...) \
+  switch (dtype) { \
+    using namespace transformer_engine; \
+    case DType::kFloat8E5M2: { \
+      using type = fp8e5m2; \
+      { __VA_ARGS__ } \
+    } break; \
+    case DType::kFloat8E4M3: { \
+      using type = fp8e4m3; \
+      { __VA_ARGS__ } \
+    } break; \
+    case DType::kInt8: { \
+      using type = int8; \
+      { __VA_ARGS__ } \
+    } break; \
+    default: \
+      NVTE_ERROR("Invalid type."); \
+  }
+
 #define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
   switch (dtype) { \
     using namespace transformer_engine; \
......
......
@@ -383,6 +383,7 @@ enum class DType {
   kFloat8E4M3 = 7,
   kFloat8E5M2 = 8,
   kFloat8E8M0 = 9,
+  kInt8 = 10,
   kNumTypes
 };
......
......
@@ -328,6 +328,7 @@ using byte = uint8_t;
 using int32 = int32_t;
 using fp32 = float;
 using fp16 = half;
+using int8 = int8_t;
 #ifndef __HIP_PLATFORM_AMD__
 using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
......
@@ -358,6 +359,10 @@ struct TypeToDType<fp8e4m3> {
   static constexpr DType value = DType::kFloat8E4M3;
 };
 template <>
+struct TypeToDType<int8> {
+  static constexpr DType value = DType::kInt8;
+};
+template <>
 struct TypeToDType<fp8e5m2> {
   static constexpr DType value = DType::kFloat8E5M2;
 };
......
......
@@ -533,7 +533,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.dtype, InputType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
+      TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
           output.dtype, OutputType,
           TRANSFORMER_ENGINE_SWITCH_CONDITION(
......
......
@@ -257,7 +257,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 #pragma unroll
     for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
 #ifdef __HIP_PLATFORM_AMD__
-      const float other_amax = __shfl_down(amax, delta);
+      const float other_amax = __shfl_down(amax, delta, kThreadsPerWarp);
 #else
       const float other_amax = __shfl_down_sync(mask, amax, delta);
 #endif
......
@@ -266,7 +266,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       amax = fmaxf(amax, other_amax);
     }
 #ifdef __HIP_PLATFORM_AMD__
-    amax = __shfl(amax, src_lane);
+    amax = __shfl(amax, src_lane, kThreadsPerWarp);
 #else
     amax = __shfl_sync(mask, amax, src_lane);
 #endif
......
@@ -354,7 +354,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 #pragma unroll
     for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
 #ifdef __HIP_PLATFORM_AMD__
-      const float other_amax = __shfl_down(amax, delta);
+      const float other_amax = __shfl_down(amax, delta, kThreadsPerWarp);
 #else
       const float other_amax = __shfl_down_sync(mask, amax, delta);
 #endif
......
@@ -363,7 +363,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       amax = fmaxf(amax, other_amax);
     }
 #ifdef __HIP_PLATFORM_AMD__
-    amax = __shfl(amax, src_lane);
+    amax = __shfl(amax, src_lane, kThreadsPerWarp);
 #else
     amax = __shfl_sync(mask, amax, src_lane);
 #endif
......
@@ -479,7 +479,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.dtype, InputType,
-      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
+      TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
           output.dtype, OutputType,
           dim3 grid(num_blocks_x, num_blocks_y, 1);
......
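Why the extra argument: HIP's `__shfl_down`/`__shfl` default their width to the hardware wavefront, which is 64 lanes on AMD GPUs, while this kernel reduces over `kThreadsPerWarp`-lane groups (typically 32). Without an explicit width, a lane could pull amax values from the neighboring group. A host-side Python model of the width semantics (assumption: for illustration only, not the kernel code):

```python
def shfl_down(lanes, delta, width):
    """Lane i reads lane i + delta, but never across a width-aligned boundary."""
    out = list(lanes)
    for i in range(len(lanes)):
        src = i + delta
        if src < len(lanes) and src // width == i // width:
            out[i] = lanes[src]
    return out

wave = list(range(64))                    # one 64-lane AMD wavefront
assert shfl_down(wave, 16, 32)[15] == 31  # width=32: source stays in lane group 0
assert shfl_down(wave, 16, 64)[16] == 32  # default width 64: crosses into group 1
```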
......
@@ -26,7 +26,8 @@
       .value("kFloat16", transformer_engine::DType::kFloat16) \
       .value("kBFloat16", transformer_engine::DType::kBFloat16) \
       .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
-      .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \
+      .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
+      .value("kInt8", transformer_engine::DType::kInt8); \
   pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
       .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
       .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \
......
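Once exported, the Python side can reference the new member directly, e.g. (a hedged check; module name as used elsewhere in TE's PyTorch bindings):

```python
import transformer_engine_torch as tex

assert int(tex.DType.kInt8) == 10  # matches kInt8 = 10 in the DType enum
```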
......
@@ -990,6 +990,7 @@ using fp8e4m3 = te_hip_fp8_e4m3;
 using fp8e5m2 = te_hip_fp8_e5m2;
 #endif
 using e8m0_t = uint8_t;
+using int8 = int8_t;
 constexpr uint32_t FP32_MANTISSA_BITS = 23;
 constexpr uint32_t FP32_EXPONENT_BIAS = 127;
......
@@ -1015,6 +1016,12 @@ struct Numeric_Traits<fp8e5m2> {
   static constexpr double maxNorm = 57344;
 };
+template <>
+struct Numeric_Traits<int8> {
+  static constexpr int maxUnbiasedExponent = 0;
+  static constexpr double maxNorm = 127;
+};
+
 template <typename T>
 struct Quantized_Limits {
   static constexpr int max_unbiased_exponent = Numeric_Traits<T>::maxUnbiasedExponent;
......
......
@@ -18,6 +18,7 @@ TE_DType = {
     torch.uint8: tex.DType.kByte,
     torch.float8_e4m3fn: tex.DType.kFloat8E4M3,
     torch.float8_e5m2: tex.DType.kFloat8E5M2,
+    torch.int8: tex.DType.kInt8,
     torch.int32: tex.DType.kInt32,
     torch.float32: tex.DType.kFloat32,
     torch.half: tex.DType.kFloat16,
......
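With the mapping in place, a `torch.int8` tensor handed to the extension resolves to the new enum value; a quick sanity check (assuming the mapping lives in `transformer_engine.pytorch.constants`, as in upstream TE):

```python
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType

assert TE_DType[torch.int8] == tex.DType.kInt8
```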
......
@@ -60,6 +60,8 @@ def check_mxfp8_support() -> Tuple[bool, str]:
 def check_fp8_block_scaling_support() -> Tuple[bool, str]:
     """Return if fp8 block scaling support is available"""
+    if IS_HIP_EXTENSION:
+        return True, ""
     if (
         get_device_compute_capability() >= (9, 0)
         and get_device_compute_capability() < (10, 0)
......
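Typical use is to gate tests on the helper, which after this change reports support unconditionally on ROCm (`IS_HIP_EXTENSION`) builds; a usage sketch:

```python
import pytest

available, reason = check_fp8_block_scaling_support()
if not available:
    pytest.skip(reason)
```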