Commit ccb9a1b1 authored by yuguo's avatar yuguo
Browse files

Merge branch 'develop_v2.4' into 'main'

[DCU] support casting master weights to int8

See merge request dcutoolkit/deeplearing/TransformerEngine!23
parents 2cbe1b70 0a8072fa
...@@ -65,9 +65,9 @@ def setup_common_extension() -> CMakeExtension: ...@@ -65,9 +65,9 @@ def setup_common_extension() -> CMakeExtension:
cmake_flags = [] cmake_flags = []
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"))): if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"))):
cmake_flags.append("-DNVTE_BUILD_SUPPRESS_UNUSED_WARNING=ON") cmake_flags.append("-DNVTE_BUILD_SUPPRESS_UNUSED_WARNING=ON")
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "0"))): if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "1"))):
cmake_flags.append("-DNVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=ON") cmake_flags.append("-DNVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=ON")
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "0"))): if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "1"))):
cmake_flags.append("-DNVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=ON") cmake_flags.append("-DNVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=ON")
else: else:
......
...@@ -9,7 +9,7 @@ from pathlib import Path ...@@ -9,7 +9,7 @@ from pathlib import Path
import pytest import pytest
import torch import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
# NVTE_INT8_SIM_FP8=1 torchrun --nproc_per_node=4 run_cast_master_weights_to_fp8.py --quantization fp8_block
if torch.cuda.device_count() < 2: if torch.cuda.device_count() < 2:
pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.") pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")
......
...@@ -125,6 +125,9 @@ if(USE_CUDA) ...@@ -125,6 +125,9 @@ if(USE_CUDA)
transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu transpose/quantize_transpose_vector_blockwise.cu
activation/gelu.cu activation/gelu.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/fused_attn_f16_max512_seqlen.cu fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu activation/relu.cu
...@@ -165,6 +168,9 @@ else() ...@@ -165,6 +168,9 @@ else()
cudnn_utils.cpp cudnn_utils.cpp
transformer_engine.cpp transformer_engine.cpp
common.cu common.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
multi_tensor/adam.cu multi_tensor/adam.cu
multi_tensor/compute_scale.cu multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu multi_tensor/l2norm.cu
...@@ -222,6 +228,7 @@ else() ...@@ -222,6 +228,7 @@ else()
set(header_include_dir set(header_include_dir
${CMAKE_CURRENT_SOURCE_DIR}/comm_gemm_overlap/userbuffers ${CMAKE_CURRENT_SOURCE_DIR}/comm_gemm_overlap/userbuffers
${CMAKE_CURRENT_SOURCE_DIR}/activation ${CMAKE_CURRENT_SOURCE_DIR}/activation
${CMAKE_CURRENT_SOURCE_DIR}/fused_attn
${CMAKE_CURRENT_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/transpose ${CMAKE_CURRENT_SOURCE_DIR}/transpose
${CMAKE_CURRENT_SOURCE_DIR}/util ${CMAKE_CURRENT_SOURCE_DIR}/util
...@@ -234,7 +241,6 @@ else() ...@@ -234,7 +241,6 @@ else()
hipify(CUDA_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR} hipify(CUDA_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}
HEADER_INCLUDE_DIR ${header_include_dir} HEADER_INCLUDE_DIR ${header_include_dir}
IGNORES "*/amd_detail/*" IGNORES "*/amd_detail/*"
IGNORES "*/fused_attn/*"
CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json" CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json"
) )
get_hipified_list("${transformer_engine_SOURCES}" te_hip_sources) get_hipified_list("${transformer_engine_SOURCES}" te_hip_sources)
......
...@@ -280,6 +280,7 @@ TRANSFORMER_ENGINE_TYPE_NAME(half) ...@@ -280,6 +280,7 @@ TRANSFORMER_ENGINE_TYPE_NAME(half)
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16) TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3) TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2) TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
TRANSFORMER_ENGINE_TYPE_NAME(int8_t)
#if CUDA_VERSION >= 12080 #if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0) TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif #endif
...@@ -455,6 +456,37 @@ struct TypeInfo { ...@@ -455,6 +456,37 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \ NVTE_ERROR("Invalid type."); \
} }
// Dispatch macro: switches on a transformer_engine::DType value and
// instantiates __VA_ARGS__ with `type` bound to the matching C++ type.
// Covers the float outputs (float / fp16 / bf16), both FP8 flavors
// (fp8e5m2 / fp8e4m3), and — unlike the plain *_TYPE_SWITCH_OUTPUT
// variant — DType::kInt8 as well; any other dtype hits NVTE_ERROR.
// NOTE(review): the kInt8 arm uses an `int8` alias from namespace
// transformer_engine — presumably a typedef for int8_t (cf. the
// TRANSFORMER_ENGINE_TYPE_NAME(int8_t) registration in this header);
// confirm the alias actually exists, otherwise this arm will not compile.
// NOTE: no comments may be added inside the macro body itself — a `//`
// on a continuation line would comment out the trailing backslash.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(dtype, type, ...) \
  switch (dtype) { \
    using namespace transformer_engine; \
    case DType::kFloat32: { \
      using type = float; \
      { __VA_ARGS__ } \
    } break; \
    case DType::kFloat16: { \
      using type = fp16; \
      { __VA_ARGS__ } \
    } break; \
    case DType::kBFloat16: { \
      using type = bf16; \
      { __VA_ARGS__ } \
    } break; \
    case DType::kFloat8E5M2: { \
      using type = fp8e5m2; \
      { __VA_ARGS__ } \
    } break; \
    case DType::kFloat8E4M3: { \
      using type = fp8e4m3; \
      { __VA_ARGS__ } \
    } break; \
    case DType::kInt8: { \
      using type = int8; \
      { __VA_ARGS__ } \
    } break; \
    default: \
      NVTE_ERROR("Invalid type."); \
  }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \ #define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \
switch (dtype) { \ switch (dtype) { \
using namespace transformer_engine; \ using namespace transformer_engine; \
......
...@@ -53,7 +53,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) ...@@ -53,7 +53,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
float other_amax = __shfl_down(amax, delta); float other_amax = __shfl_down(amax, delta, kThreadsPerWarp);
#else #else
float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta); float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);
#endif #endif
...@@ -124,14 +124,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) ...@@ -124,14 +124,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
bool other_skip_store = __shfl_down(skip_store, delta); bool other_skip_store = __shfl_down(skip_store, delta, kThreadsPerWarp);
#else #else
bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta); bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta);
#endif #endif
skip_store = skip_store && other_skip_store; skip_store = skip_store && other_skip_store;
} }
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
skip_store = __shfl(skip_store, 0); skip_store = __shfl(skip_store, 0, kThreadsPerWarp);
#else #else
skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0); skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0);
#endif #endif
...@@ -217,7 +217,7 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s ...@@ -217,7 +217,7 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
inp.dtype(), inp_dtype, inp.dtype(), inp_dtype,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
out_dtype, fp8_type, out_dtype, fp8_type,
TRANSFORMER_ENGINE_SWITCH_CONDITION( TRANSFORMER_ENGINE_SWITCH_CONDITION(
w % kTileDim == 0, kWidthAligned, w % kTileDim == 0, kWidthAligned,
......
...@@ -211,7 +211,7 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr ...@@ -211,7 +211,7 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
NVTE_CHECK(noop.data.dptr != nullptr); NVTE_CHECK(noop.data.dptr != nullptr);
} }
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(
input.data.dtype, Type, constexpr const char *type_name = TypeInfo<Type>::name; input.data.dtype, Type, constexpr const char *type_name = TypeInfo<Type>::name;
constexpr size_t type_size = sizeof(Type); constexpr size_t type_size = sizeof(Type);
......
...@@ -60,7 +60,6 @@ def general_gemm( ...@@ -60,7 +60,6 @@ def general_gemm(
assert not gelu, "GELU not supported with int8 simulation" assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation" assert gelu_in is None, "GELU input not supported with int8 simulation"
assert bias is None, "Bias not supported with int8 simulation" assert bias is None, "Bias not supported with int8 simulation"
assert not accumulate, "Accumulation not supported with int8 simulation"
assert ub is None, "User buffer not supported with int8 simulation" assert ub is None, "User buffer not supported with int8 simulation"
assert ub_type is None, "User buffer type not supported with int8 simulation" assert ub_type is None, "User buffer type not supported with int8 simulation"
assert extra_output is None, "Extra output not supported with int8 simulation" assert extra_output is None, "Extra output not supported with int8 simulation"
...@@ -80,6 +79,11 @@ def general_gemm( ...@@ -80,6 +79,11 @@ def general_gemm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128], qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
output_dtype=out_dtype output_dtype=out_dtype
) )
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None, None return y, None, None, None
elif layout == "NN": elif layout == "NN":
...@@ -96,6 +100,11 @@ def general_gemm( ...@@ -96,6 +100,11 @@ def general_gemm(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128], qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
output_dtype=out_dtype output_dtype=out_dtype
) )
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None, None return y, None, None, None
elif layout == "NT": elif layout == "NT":
...@@ -112,6 +121,11 @@ def general_gemm( ...@@ -112,6 +121,11 @@ def general_gemm(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128], qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
output_dtype=out_dtype output_dtype=out_dtype
) )
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None, None return y, None, None, None
else: else:
...@@ -203,7 +217,6 @@ def general_grouped_gemm( ...@@ -203,7 +217,6 @@ def general_grouped_gemm(
assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now." assert len(set(m_splits)) == 1, "Int8 simulation groupgemm just surpport token pad as same as batchgemm for now."
assert not gelu, "GELU not supported with int8 simulation groupgemm." assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert bias is None, "Bias not supported with int8 simulation groupgemm." assert bias is None, "Bias not supported with int8 simulation groupgemm."
assert not accumulate, "Accumulation not supported with int8 simulation groupgemm."
if layout == "TN": if layout == "TN":
qx_data = [ qx_data = [
...@@ -219,6 +232,11 @@ def general_grouped_gemm( ...@@ -219,6 +232,11 @@ def general_grouped_gemm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128], qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
output_dtype=out_dtype output_dtype=out_dtype
) )
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None return y, None, None
elif layout == "NN": elif layout == "NN":
...@@ -235,6 +253,11 @@ def general_grouped_gemm( ...@@ -235,6 +253,11 @@ def general_grouped_gemm(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128], qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
output_dtype=out_dtype output_dtype=out_dtype
) )
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None return y, None, None
elif layout == "NT": elif layout == "NT":
...@@ -251,6 +274,11 @@ def general_grouped_gemm( ...@@ -251,6 +274,11 @@ def general_grouped_gemm(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128], qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
output_dtype=out_dtype output_dtype=out_dtype
) )
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None return y, None, None
else: else:
......
...@@ -36,8 +36,9 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const ...@@ -36,8 +36,9 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const
"input must be a float or bfloat16 tensor"); "input must be a float or bfloat16 tensor");
TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor"); TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor");
TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 || TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 ||
out_dtype == transformer_engine::DType::kFloat8E5M2, out_dtype == transformer_engine::DType::kFloat8E5M2 ||
"out_dtype must be kFloat8E4M3 or kFloat8E5M2"); out_dtype == transformer_engine::DType::kInt8,
"out_dtype must be kFloat8E4M3 or kFloat8E5M2 or kInt8");
const TensorWrapper inp_cu = makeTransformerEngineTensor(inp); const TensorWrapper inp_cu = makeTransformerEngineTensor(inp);
TensorWrapper out_cu = makeTransformerEngineTensor(out); TensorWrapper out_cu = makeTransformerEngineTensor(out);
......
...@@ -414,6 +414,8 @@ def _cast_master_weights_to_fp8_blockwise_scaling( ...@@ -414,6 +414,8 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
max_fp8 = 448.0 max_fp8 = 448.0
elif fp8_dtype == tex.DType.kFloat8E5M2: elif fp8_dtype == tex.DType.kFloat8E5M2:
max_fp8 = 57344.0 max_fp8 = 57344.0
elif fp8_dtype == tex.DType.kInt8:
max_fp8 = 127.0
else: else:
raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}") raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")
multi_tensor_applier( multi_tensor_applier(
...@@ -435,7 +437,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling( ...@@ -435,7 +437,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
# We cannot create columnwise data here because users (like megatron) may want to overlap # We cannot create columnwise data here because users (like megatron) may want to overlap
# the all-gather of model weights and forward process, so the model weight is not updated # the all-gather of model weights and forward process, so the model weight is not updated
# at this moment. # at this moment.
model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) # May cause core dump in iter 2
# If master weight is None, it means that the master weight of the current model weight # If master weight is None, it means that the master weight of the current model weight
# is in other DP ranks. # is in other DP ranks.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment