Commit ccb9a1b1 authored by yuguo

Merge branch 'develop_v2.4' into 'main'

[DCU] support casting master weights to int8

See merge request dcutoolkit/deeplearing/TransformerEngine!23
parents 2cbe1b70 0a8072fa
......@@ -65,9 +65,9 @@ def setup_common_extension() -> CMakeExtension:
cmake_flags = []
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"))):
cmake_flags.append("-DNVTE_BUILD_SUPPRESS_UNUSED_WARNING=ON")
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "0"))):
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "1"))):
cmake_flags.append("-DNVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=ON")
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "0"))):
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "1"))):
cmake_flags.append("-DNVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=ON")
else:
......
......@@ -9,7 +9,7 @@ from pathlib import Path
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
# NVTE_INT8_SIM_FP8=1 torchrun --nproc_per_node=4 run_cast_master_weights_to_fp8.py --quantization fp8_block
if torch.cuda.device_count() < 2:
pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")
......
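The comment added above shows how the new distributed test is meant to be launched; NVTE_INT8_SIM_FP8=1 appears to select the int8 simulation path. A minimal sketch of how such an environment toggle is typically read on the Python side, following the bool(int(os.getenv(...))) pattern used in setup.py above (the helper name is hypothetical):

import os

# Hypothetical sketch: decide whether GEMMs go through the int8 simulation path.
def int8_sim_fp8_enabled() -> bool:
    return bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))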
......@@ -125,6 +125,9 @@ if(USE_CUDA)
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
activation/gelu.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu
......@@ -165,6 +168,9 @@ else()
cudnn_utils.cpp
transformer_engine.cpp
common.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
......@@ -222,6 +228,7 @@ else()
set(header_include_dir
${CMAKE_CURRENT_SOURCE_DIR}/comm_gemm_overlap/userbuffers
${CMAKE_CURRENT_SOURCE_DIR}/activation
${CMAKE_CURRENT_SOURCE_DIR}/fused_attn
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/transpose
${CMAKE_CURRENT_SOURCE_DIR}/util
......@@ -234,7 +241,6 @@ else()
hipify(CUDA_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}
HEADER_INCLUDE_DIR ${header_include_dir}
IGNORES "*/amd_detail/*"
IGNORES "*/fused_attn/*"
CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json"
)
get_hipified_list("${transformer_engine_SOURCES}" te_hip_sources)
......
......@@ -280,6 +280,7 @@ TRANSFORMER_ENGINE_TYPE_NAME(half)
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
TRANSFORMER_ENGINE_TYPE_NAME(int8_t)
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
#endif
......@@ -455,6 +456,37 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E5M2: { \
using type = fp8e5m2; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E4M3: { \
using type = fp8e4m3; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt8: { \
using type = int8; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
......
......@@ -53,7 +53,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
float other_amax = __shfl_down(amax, delta);
float other_amax = __shfl_down(amax, delta, kThreadsPerWarp);
#else
float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);
#endif
......@@ -124,14 +124,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
bool other_skip_store = __shfl_down(skip_store, delta);
bool other_skip_store = __shfl_down(skip_store, delta, kThreadsPerWarp);
#else
bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta);
#endif
skip_store = skip_store && other_skip_store;
}
#ifdef __HIP_PLATFORM_AMD__
skip_store = __shfl(skip_store, 0);
skip_store = __shfl(skip_store, 0, kThreadsPerWarp);
#else
skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0);
#endif
......@@ -217,7 +217,7 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
inp.dtype(), inp_dtype,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
out_dtype, fp8_type,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
w % kTileDim == 0, kWidthAligned,
......
......@@ -211,7 +211,7 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
NVTE_CHECK(noop.data.dptr != nullptr);
}
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT_WITH_INT8(
input.data.dtype, Type, constexpr const char *type_name = TypeInfo<Type>::name;
constexpr size_t type_size = sizeof(Type);
......
......@@ -60,7 +60,6 @@ def general_gemm(
assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation"
assert bias is None, "Bias not supported with int8 simulation"
assert not accumulate, "Accumulation not supported with int8 simulation"
assert ub is None, "User buffer not supported with int8 simulation"
assert ub_type is None, "User buffer type not supported with int8 simulation"
assert extra_output is None, "Extra output not supported with int8 simulation"
......@@ -80,6 +79,11 @@ def general_gemm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
output_dtype=out_dtype
)
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None, None
elif layout == "NN":
......@@ -96,6 +100,11 @@ def general_gemm(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
output_dtype=out_dtype
)
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None, None
elif layout == "NT":
......@@ -112,6 +121,11 @@ def general_gemm(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
output_dtype=out_dtype
)
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None, None
else:
......@@ -203,7 +217,6 @@ def general_grouped_gemm(
assert len(set(m_splits)) == 1, "Int8 simulation grouped GEMM currently only supports token padding identical to the batched GEMM path."
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert bias is None, "Bias not supported with int8 simulation groupgemm."
assert not accumulate, "Accumulation not supported with int8 simulation groupgemm."
if layout == "TN":
qx_data = [
......@@ -219,6 +232,11 @@ def general_grouped_gemm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
output_dtype=out_dtype
)
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None
elif layout == "NN":
......@@ -235,6 +253,11 @@ def general_grouped_gemm(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
output_dtype=out_dtype
)
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None
elif layout == "NT":
......@@ -251,6 +274,11 @@ def general_grouped_gemm(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
output_dtype=out_dtype
)
if accumulate:
assert out is not None
y = y + out
else:
assert out is None, "Output tensor should be None when accumulate is False."
return y, None, None
else:
......
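The hunks above replace the blanket "Accumulation not supported" assertions with explicit handling: when accumulate is true, the simulated GEMM result is added to the existing output tensor (D = A @ B + C); otherwise out must be None. A minimal sketch of the contract in isolation (names hypothetical; the real path goes through the blockwise quant/dequant simulation shown above):

import torch

def simulated_gemm(a: torch.Tensor, b: torch.Tensor, out=None, accumulate=False,
                   out_dtype=torch.bfloat16):
    # Stand-in for the int8-simulated blockwise GEMM used above.
    y = (a.float() @ b.float().t()).to(out_dtype)
    if accumulate:
        # Same contract as the new code path: accumulate into the provided output.
        assert out is not None, "accumulate=True requires an existing output tensor"
        y = y + out
    else:
        assert out is None, "Output tensor should be None when accumulate is False."
    return y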
......@@ -36,8 +36,9 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const
"input must be a float or bfloat16 tensor");
TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor");
TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 ||
out_dtype == transformer_engine::DType::kFloat8E5M2,
"out_dtype must be kFloat8E4M3 or kFloat8E5M2");
out_dtype == transformer_engine::DType::kFloat8E5M2 ||
out_dtype == transformer_engine::DType::kInt8,
"out_dtype must be kFloat8E4M3 or kFloat8E5M2 or kInt8");
const TensorWrapper inp_cu = makeTransformerEngineTensor(inp);
TensorWrapper out_cu = makeTransformerEngineTensor(out);
......
......@@ -414,6 +414,8 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
max_fp8 = 448.0
elif fp8_dtype == tex.DType.kFloat8E5M2:
max_fp8 = 57344.0
elif fp8_dtype == tex.DType.kInt8:
max_fp8 = 127.0
else:
raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")
multi_tensor_applier(
......@@ -435,7 +437,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
# We cannot create columnwise data here because users (like megatron) may want to overlap
# the all-gather of model weights and forward process, so the model weight is not updated
# at this moment.
model_weight.update_usage(rowwise_usage=True, columnwise_usage=False)
model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) # May cause core dump in iter 2
# If master weight is None, it means that the master weight of the current model weight
# is in other DP ranks.
......
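With kInt8 added above, the blockwise-scaling cast reuses the same amax-based scale computation with a representable maximum of 127 instead of 448 (E4M3) or 57344 (E5M2). A minimal sketch of per-(128x128)-block symmetric quantization under that assumption (the helper and layout are hypothetical, and the tensor is assumed to be evenly divisible by the block size):

import torch

def blockwise_int8_cast(master_weight: torch.Tensor, block: int = 128):
    # Hypothetical sketch: quantize each 128x128 tile with scale = 127 / amax(tile).
    max_int8 = 127.0
    h, w = master_weight.shape
    qdata = torch.empty(h, w, dtype=torch.int8)
    scales = torch.empty(h // block, w // block, dtype=torch.float32)
    for i in range(0, h, block):
        for j in range(0, w, block):
            tile = master_weight[i:i + block, j:j + block].float()
            amax = tile.abs().amax().clamp_min(1e-12)
            scale = max_int8 / amax  # plays the role of max_fp8 / amax above
            scales[i // block, j // block] = scale
            qdata[i:i + block, j:j + block] = torch.clamp(
                torch.round(tile * scale), -max_int8, max_int8).to(torch.int8)
    return qdata, scales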