Commit e32965ff authored by wenjh

Merge branch 'develop_v2.8' into 'main'

Fix user args core dump in mt

See merge request dcutoolkit/deeplearing/TransformerEngine!57
parents 4b65dfa3 a13c52ad
@@ -17,6 +17,7 @@ import transformer_engine_torch as tex
 import nvdlfw_inspect.api as debug_api
 from transformer_engine.debug import set_weight_tensor_tp_group_reduce
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from test_numerics import (
     _emulate_linear,
@@ -47,7 +48,6 @@ TEST_NR = 0
 fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
 def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
     if tp_size is None:
         tp_size = WORLD_SIZE
@@ -72,6 +72,16 @@ def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None,
 def _init_model(weight, parallel_mode=None, tp_group=None, name="linear"):
+    if IS_HIP_EXTENSION:
+        model = transformer_engine.pytorch.Linear(
+            IN_SIZE,
+            OUT_SIZE,
+            name=name,
+            bias=False,
+            parallel_mode=parallel_mode,
+            tp_group=(tp_group or NCCL_WORLD if parallel_mode else None),
+        )
+    else:
         model = transformer_engine.pytorch.Linear(
             IN_SIZE,
             OUT_SIZE,
@@ -363,7 +373,6 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
     )
     set_weight_tensor_tp_group_reduce(True)  # reset
 @run_debug_test
 def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
     from test_log import LOG_QUANTIZED_CONFIG
@@ -580,6 +589,9 @@ def test_fake_quant_fp8(
         "dgrad_fp8": not (dgrad_weight or dgrad_grad),
         "wgrad_fp8": not (wgrad_grad or wgrad_input),
     }
+    if IS_HIP_EXTENSION:
+        if fp8_kwargs["fprop_fp8"] or fp8_kwargs["dgrad_fp8"] or fp8_kwargs["wgrad_fp8"]:
+            return  # Output type 32 (FP32) does not support int8 simulation.
     if WORLD_RANK == 0:
         fake_quant_fp8_create_config(
             fprop_inp,
@@ -667,6 +679,10 @@ if __name__ == "__main__":
     random.seed(SEED)
     _init_distributed()
+    if IS_HIP_EXTENSION:
+        # Output type 32 (FP32) does not support int8 simulation.
+        pass
+    else:
         test_log_expert_parallel()
     for parallel_mode in ["column", "row"]:
         for gather_weight in [True, False]:
@@ -676,6 +692,11 @@ if __name__ == "__main__":
     for parallel_mode in ["row", "column"]:
         test_disable_fp8_layer(parallel_mode)
+    if IS_HIP_EXTENSION:
+        # Output type 32 (FP32) does not support int8 simulation.
+        pass
+    else:
         # test_disable_fp8_gemms
         _run_test_with_combinations(
             test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"]
@@ -690,7 +711,10 @@ if __name__ == "__main__":
         extra_args=["column", "row"],
         sample_size=20,
     )
+    if IS_HIP_EXTENSION:
+        # Output type 32 (FP32) does not support int8 simulation.
+        pass
+    else:
         _run_test_with_combinations(
             test_per_tensor_scaling,
             all_boolean,
......
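The ROCm guards added above all share one shape: when IS_HIP_EXTENSION is true, the FP8-dependent branch is skipped because FP32 output does not support the int8 simulation of FP8 used on that platform. A minimal sketch of that pattern as a standalone helper (the helper name is illustrative and not part of this change):

    from torch.utils.cpp_extension import IS_HIP_EXTENSION

    def run_unless_rocm_int8_sim(test_fn, *args, **kwargs):
        # Hypothetical helper: centralizes the guard wrapped around the
        # FP8-only tests above. On ROCm builds FP32 output does not support
        # int8 simulation of FP8, so the call is skipped entirely.
        if IS_HIP_EXTENSION:
            return
        test_fn(*args, **kwargs)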
@@ -509,7 +509,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
 def test_linear():
     """Run linear layer tests with various configurations."""
-    kwargs_list = [
+    base_kwargs_list = [
         {},
         {"bias": False},
         {"init_method": _constant},
@@ -519,6 +519,15 @@ def test_linear():
         {"delay_wgrad_compute": True},
         {"save_original_input": True},
     ]
+    # TODO: The blockwise recipe does not currently support bias=True.
+    #
+    # On AMD platforms, when the quantization recipe is fp8_block_scaling,
+    # keep only the configurations from base_kwargs_list that explicitly
+    # set bias=False; otherwise use the full list unchanged.
+    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
+        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
+    else:
+        kwargs_list = base_kwargs_list
     for kwargs in kwargs_list:
         if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
             continue
@@ -688,7 +697,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
 def test_layernorm_linear():
-    kwargs_list = [
+    base_kwargs_list = [
         {},
         {"bias": False},
         {"init_method": _constant},
@@ -699,6 +708,15 @@ def test_layernorm_linear():
         {"return_layernorm_output": True},
         {"delay_wgrad_compute": True},
     ]
+    # TODO: The blockwise recipe does not currently support bias=True.
+    #
+    # On AMD platforms, when the quantization recipe is fp8_block_scaling,
+    # keep only the configurations from base_kwargs_list that explicitly
+    # set bias=False; otherwise use the full list unchanged.
+    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
+        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
+    else:
+        kwargs_list = base_kwargs_list
     for kwargs in kwargs_list:
         for parallel_mode in ["column"]:
             for sequence_parallel in [False, True]:
@@ -793,7 +811,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
 def test_layernorm_mlp():
-    kwargs_list = [
+    base_kwargs_list = [
         {},
         {"init_method": _constant},
         {"output_layer_init_method": _constant},
@@ -807,7 +825,15 @@ def test_layernorm_mlp():
         {"return_layernorm_output": True},
         {"delay_wgrad_compute": True},
     ]
+    # TODO: The blockwise recipe does not currently support bias=True.
+    #
+    # On AMD platforms, when the quantization recipe is fp8_block_scaling,
+    # keep only the configurations from base_kwargs_list that explicitly
+    # set bias=False; otherwise use the full list unchanged.
+    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
+        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
+    else:
+        kwargs_list = base_kwargs_list
     for kwargs in kwargs_list:
         for set_parallel_mode in [True]:
             for sequence_parallel in [False, True]:
@@ -882,7 +908,7 @@ def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs):
 def test_transformer_layer():
-    kwargs_list = [
+    base_kwargs_list = [
         {},
         {"num_gqa_groups": 4},
         {"init_method": _constant},
@@ -902,6 +928,15 @@ def test_transformer_layer():
         {"fuse_qkv_params": True},
         {"activation": "relu"},
     ]
+    # TODO: The blockwise recipe does not currently support bias=True.
+    #
+    # On AMD platforms, when the quantization recipe is fp8_block_scaling,
+    # keep only the configurations from base_kwargs_list that explicitly
+    # set bias=False; otherwise use the full list unchanged.
+    if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
+        kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
+    else:
+        kwargs_list = base_kwargs_list
     for kwargs in kwargs_list:
         for sequence_parallel in [False, True]:
......
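Each of the four test entry points above applies the same bias filter when the fp8_block_scaling recipe runs on ROCm. A minimal sketch of that shared logic factored into one helper (the function name is illustrative; the tests inline the expression instead):

    def _filter_bias_cases(base_kwargs_list):
        # Hypothetical helper: the blockwise recipe does not yet support
        # bias=True on ROCm, so keep only the configurations that explicitly
        # disable the bias; otherwise return the list unchanged.
        if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
            return [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
        return base_kwargs_list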
@@ -9,7 +9,8 @@ from pathlib import Path
 import pytest
 import torch
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
+import transformer_engine as te
 """
 Distributed numerics tests
@@ -61,4 +62,15 @@ def test_distributed(quantization):
         pytest.skip(reason_for_no_mxfp8)
     if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
         pytest.skip(reason_for_no_fp8_block_scaling)
+    if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
+        import importlib
+        ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", "None")
+        os.environ["NVTE_INT8_SIM_FP8"] = "1"
+        importlib.reload(te.pytorch.fp8)
     _run_test(quantization)
+    if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
+        if ori_int8_sim_fp8 is None or ori_int8_sim_fp8 == "None":
+            os.environ["NVTE_INT8_SIM_FP8"] = "0"
+        else:
+            del os.environ["NVTE_INT8_SIM_FP8"]
+        importlib.reload(te.pytorch.fp8)
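The block above flips NVTE_INT8_SIM_FP8 on, reloads transformer_engine.pytorch.fp8 so the setting takes effect, runs the test, and then reverts the variable. A minimal sketch of the same toggle written as a context manager that always restores the variable's original value, even if the test raises (the helper name is illustrative and not part of this change):

    import contextlib
    import importlib
    import os

    import transformer_engine as te

    @contextlib.contextmanager
    def int8_sim_fp8_enabled():
        # Enable int8 simulation of FP8, reload the fp8 module so it picks
        # up the new setting, and restore the previous state on exit.
        original = os.environ.get("NVTE_INT8_SIM_FP8")
        os.environ["NVTE_INT8_SIM_FP8"] = "1"
        importlib.reload(te.pytorch.fp8)
        try:
            yield
        finally:
            if original is None:
                os.environ.pop("NVTE_INT8_SIM_FP8", None)
            else:
                os.environ["NVTE_INT8_SIM_FP8"] = original
            importlib.reload(te.pytorch.fp8)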
@@ -1352,82 +1352,41 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
   NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
 }
-class userArgsManager {
- public:
-  userArgsManager() {}
-  ~userArgsManager() {
-    // Release all userArgs when the manager is destroyed
-    for (auto& device_pair : userArgs_map_) {
-      hipFree(device_pair.second);  // Only one userArgs per device
-    }
-  }
-  // Get a userArgs for the given device (creates if necessary)
-  hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
-    std::lock_guard<std::mutex> lock(mutex_);
-    // Check if the userArgs for this device exists
-    auto device_it = userArgs_map_.find(device_id);
-    if (device_it != userArgs_map_.end()) {
-      return device_it->second;
-    }
-    // Create a new userArgs for this device if it doesn't exist
-    hipblaslt_ext::UserArguments* userArgs;
-    NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
-    // Store the userArgs in the map for this device
-    userArgs_map_[device_id] = userArgs;
-    return userArgs;
-  }
- private:
-  std::unordered_map<int, hipblaslt_ext::UserArguments*>
-      userArgs_map_;  // Map from device_id to hipblasHandle
-  std::mutex mutex_;
-};
-class d_userArgsManager {
- public:
-  d_userArgsManager() {}
-  ~d_userArgsManager() {
-    // Release all userArgs when the manager is destroyed
-    for (auto& device_pair : d_userArgs_map_) {
-      hipFree(device_pair.second);  // Only one userArgs per device
-    }
-  }
-  // Get a userArgs for the given device (creates if necessary)
-  hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
-    std::lock_guard<std::mutex> lock(mutex_);
-    // Check if the userArgs for this device exists
-    auto device_it = d_userArgs_map_.find(device_id);
-    if (device_it != d_userArgs_map_.end()) {
-      return device_it->second;
-    }
-    // Create a new userArgs for this device if it doesn't exist
-    hipblaslt_ext::UserArguments* d_userArgs;
-    NVTE_CHECK_CUDA(hipMalloc(&d_userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
-    // Store the userArgs in the map for this device
-    d_userArgs_map_[device_id] = d_userArgs;
-    return d_userArgs;
-  }
- private:
-  std::unordered_map<int, hipblaslt_ext::UserArguments*>
-      d_userArgs_map_;  // Map from device_id to hipblasHandle
-  std::mutex mutex_;
-};
-// Define a static userArgs manager
-static userArgsManager UAManager;
-static d_userArgsManager d_UAManager;
+struct HipBlasLtUserArgsDeleter {
+  void operator()(hipblaslt_ext::UserArguments* ptr) const noexcept {
+    hipFree(ptr);
+  }
+};
+using HipBlasLtUserArgsPtr = std::unique_ptr<hipblaslt_ext::UserArguments, HipBlasLtUserArgsDeleter>;
+inline HipBlasLtUserArgsPtr make_hipblaslt_user_args_ptr(size_t size, bool host) {
+  hipblaslt_ext::UserArguments* raw_ptr = nullptr;
+  if (host) {
+    NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
+  } else {
+    NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
+  }
+  return HipBlasLtUserArgsPtr(raw_ptr);
+}
+inline hipblaslt_ext::UserArguments* get_hipblaslt_user_args(size_t size, bool host) {
+  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> host_userargs_cache;
+  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> device_userargs_cache;
+  std::unordered_map<size_t, HipBlasLtUserArgsPtr>& user_args_cache = host ? host_userargs_cache : device_userargs_cache;
+  auto size_it = user_args_cache.find(size);
+  if (size_it != user_args_cache.end()) {
+    return size_it->second.get();
+  }
+  else
+  {
+    HipBlasLtUserArgsPtr user_args = make_hipblaslt_user_args_ptr(size, host);
+    hipblaslt_ext::UserArguments* raw_ptr = user_args.get();
+    user_args_cache[size] = std::move(user_args);
+    return raw_ptr;
+  }
+}
 void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
                            std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
@@ -1438,10 +1397,8 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
   // Check compute_stream_offset valid.
   NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
-  int device_id;
-  hipGetDevice(&device_id);
-  hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
-  hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
+  hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
+  hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);
   // hipblaslt_ext::UserArguments* userArgs;
   // NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
......
@@ -217,7 +217,13 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
   // Choose between runtime-compiled or statically-compiled kernel
   const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
-  if (aligned && rtc::is_enabled()) {  // Runtime-compiled tuned kernel
+  // TODO: Using RTC can crash the kernel on ROCm, so set use_rtc to false there to skip the RTC path until the crash is resolved.
+#ifdef USE_ROCM
+  const bool use_rtc = false;
+#else
+  const bool use_rtc = true;
+#endif
+  if (aligned && rtc::is_enabled() && use_rtc) {  // Runtime-compiled tuned kernel
     // Pick kernel config
     std::vector<KernelConfig> kernel_configs;
     kernel_configs.reserve(16);
......