Commit e32965ff authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.8' into 'main'

Fix user args core dump in mt

See merge request dcutoolkit/deeplearing/TransformerEngine!57
parents 4b65dfa3 a13c52ad
......@@ -17,6 +17,7 @@ import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from test_numerics import (
_emulate_linear,
......@@ -47,7 +48,6 @@ TEST_NR = 0
fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
if tp_size is None:
tp_size = WORLD_SIZE
......@@ -72,6 +72,16 @@ def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None,
def _init_model(weight, parallel_mode=None, tp_group=None, name="linear"):
if IS_HIP_EXTENSION:
model = transformer_engine.pytorch.Linear(
IN_SIZE,
OUT_SIZE,
name=name,
bias=False,
parallel_mode=parallel_mode,
tp_group=(tp_group or NCCL_WORLD if parallel_mode else None),
)
else:
model = transformer_engine.pytorch.Linear(
IN_SIZE,
OUT_SIZE,
......@@ -363,7 +373,6 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
)
set_weight_tensor_tp_group_reduce(True) # reset
@run_debug_test
def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
from test_log import LOG_QUANTIZED_CONFIG
......@@ -580,6 +589,9 @@ def test_fake_quant_fp8(
"dgrad_fp8": not (dgrad_weight or dgrad_grad),
"wgrad_fp8": not (wgrad_grad or wgrad_input),
}
if IS_HIP_EXTENSION:
if fp8_kwargs["fprop_fp8"] or fp8_kwargs["dgrad_fp8"] or fp8_kwargs["wgrad_fp8"]:
return # Output type 32 (FP32) does not support int8 simulation.
if WORLD_RANK == 0:
fake_quant_fp8_create_config(
fprop_inp,
......@@ -667,6 +679,10 @@ if __name__ == "__main__":
random.seed(SEED)
_init_distributed()
if IS_HIP_EXTENSION:
# Output type 32 (FP32) does not support int8 simulation.
pass
else:
test_log_expert_parallel()
for parallel_mode in ["column", "row"]:
for gather_weight in [True, False]:
......@@ -676,6 +692,11 @@ if __name__ == "__main__":
for parallel_mode in ["row", "column"]:
test_disable_fp8_layer(parallel_mode)
if IS_HIP_EXTENSION:
# Output type 32 (FP32) does not support int8 simulation.
pass
else:
# test_disable_fp8_gemms
_run_test_with_combinations(
test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"]
......@@ -690,7 +711,10 @@ if __name__ == "__main__":
extra_args=["column", "row"],
sample_size=20,
)
if IS_HIP_EXTENSION:
# Output type 32 (FP32) does not support int8 simulation.
pass
else:
_run_test_with_combinations(
test_per_tensor_scaling,
all_boolean,
......
......@@ -509,7 +509,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
def test_linear():
"""Run linear layer tests with various configurations."""
kwargs_list = [
base_kwargs_list = [
{},
{"bias": False},
{"init_method": _constant},
......@@ -519,6 +519,15 @@ def test_linear():
{"delay_wgrad_compute": True},
{"save_original_input": True},
]
#TODO:The blockwise recipe does not currently support calculations with bias set to true.
"""
For AMD platforms, when the quantization recipe is fp8_block_scaling, iterate through base_kwargs_list,
and if the bias value is not set in kwargs or the bias value is true, set bias to false.
"""
if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
else:
kwargs_list = base_kwargs_list
for kwargs in kwargs_list:
if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
continue
......@@ -688,7 +697,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
def test_layernorm_linear():
kwargs_list = [
base_kwargs_list = [
{},
{"bias": False},
{"init_method": _constant},
......@@ -699,6 +708,15 @@ def test_layernorm_linear():
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
]
#TODO:The blockwise recipe does not currently support calculations with bias set to true.
"""
For AMD platforms, when the quantization recipe is fp8_block_scaling, iterate through base_kwargs_list,
and if the bias value is not set in kwargs or the bias value is true, set bias to false.
"""
if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
else:
kwargs_list = base_kwargs_list
for kwargs in kwargs_list:
for parallel_mode in ["column"]:
for sequence_parallel in [False, True]:
......@@ -793,7 +811,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
def test_layernorm_mlp():
kwargs_list = [
base_kwargs_list = [
{},
{"init_method": _constant},
{"output_layer_init_method": _constant},
......@@ -807,7 +825,15 @@ def test_layernorm_mlp():
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
]
#TODO:The blockwise recipe does not currently support calculations with bias set to true.
"""
For AMD platforms, when the quantization recipe is fp8_block_scaling, iterate through base_kwargs_list,
and if the bias value is not set in kwargs or the bias value is true, set bias to false.
"""
if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
else:
kwargs_list = base_kwargs_list
for kwargs in kwargs_list:
for set_parallel_mode in [True]:
for sequence_parallel in [False, True]:
......@@ -882,7 +908,7 @@ def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs):
def test_transformer_layer():
kwargs_list = [
base_kwargs_list = [
{},
{"num_gqa_groups": 4},
{"init_method": _constant},
......@@ -902,6 +928,15 @@ def test_transformer_layer():
{"fuse_qkv_params": True},
{"activation": "relu"},
]
#TODO:The blockwise recipe does not currently support calculations with bias set to true.
"""
For AMD platforms, when the quantization recipe is fp8_block_scaling, iterate through base_kwargs_list,
and if the bias value is not set in kwargs or the bias value is true, set bias to false.
"""
if IS_HIP_EXTENSION and QUANTIZATION == "fp8_block_scaling":
kwargs_list = [kwargs for kwargs in base_kwargs_list if kwargs.get("bias", True) is False]
else:
kwargs_list = base_kwargs_list
for kwargs in kwargs_list:
for sequence_parallel in [False, True]:
......
......@@ -9,7 +9,8 @@ from pathlib import Path
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine as te
"""
Distributed numerics tests
......@@ -61,4 +62,15 @@ def test_distributed(quantization):
pytest.skip(reason_for_no_mxfp8)
if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
import importlib
ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", "None")
os.environ["NVTE_INT8_SIM_FP8"] = "1"
importlib.reload(te.pytorch.fp8)
_run_test(quantization)
if IS_HIP_EXTENSION and quantization == "fp8_block_scaling":
if ori_int8_sim_fp8 is None or ori_int8_sim_fp8 == "None":
os.environ["NVTE_INT8_SIM_FP8"] = "0"
else:
del os.environ["NVTE_INT8_SIM_FP8"]
importlib.reload(te.pytorch.fp8)
......@@ -1352,82 +1352,41 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}
class userArgsManager {
public:
userArgsManager() {}
~userArgsManager() {
// Release all userArgs when the manager is destroyed
for (auto& device_pair : userArgs_map_) {
hipFree(device_pair.second); // Only one userArgs per device
}
}
// Get a userArgs for the given device (creates if necessary)
hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
std::lock_guard<std::mutex> lock(mutex_);
// Check if the userArgs for this device exists
auto device_it = userArgs_map_.find(device_id);
if (device_it != userArgs_map_.end()) {
return device_it->second;
struct HipBlasLtUserArgsDeleter {
void operator()(hipblaslt_ext::UserArguments* ptr) const noexcept {
hipFree(ptr);
}
// Create a new userArgs for this device if it doesn't exist
hipblaslt_ext::UserArguments* userArgs;
NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
// Store the userArgs in the map for this device
userArgs_map_[device_id] = userArgs;
return userArgs;
}
private:
std::unordered_map<int, hipblaslt_ext::UserArguments*>
userArgs_map_; // Map from device_id to hipblasHandle
std::mutex mutex_;
};
class d_userArgsManager {
public:
d_userArgsManager() {}
using HipBlasLtUserArgsPtr = std::unique_ptr<hipblaslt_ext::UserArguments, HipBlasLtUserArgsDeleter>;
~d_userArgsManager() {
// Release all userArgs when the manager is destroyed
for (auto& device_pair : d_userArgs_map_) {
hipFree(device_pair.second); // Only one userArgs per device
}
inline HipBlasLtUserArgsPtr make_hipblaslt_user_args_ptr(size_t size, bool host) {
hipblaslt_ext::UserArguments* raw_ptr = nullptr;
if (host) {
NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
} else {
NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
}
return HipBlasLtUserArgsPtr(raw_ptr);
}
// Get a userArgs for the given device (creates if necessary)
hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
std::lock_guard<std::mutex> lock(mutex_);
// Check if the userArgs for this device exists
auto device_it = d_userArgs_map_.find(device_id);
if (device_it != d_userArgs_map_.end()) {
return device_it->second;
inline hipblaslt_ext::UserArguments* get_hipblaslt_user_args(size_t size, bool host) {
thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> host_userargs_cache;
thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> device_userargs_cache;
std::unordered_map<size_t, HipBlasLtUserArgsPtr>& user_args_cache = host ? host_userargs_cache : device_userargs_cache;
auto size_it = user_args_cache.find(size);
if (size_it != user_args_cache.end()) {
return size_it->second.get();
}
// Create a new userArgs for this device if it doesn't exist
hipblaslt_ext::UserArguments* d_userArgs;
NVTE_CHECK_CUDA(hipMalloc(&d_userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
// Store the userArgs in the map for this device
d_userArgs_map_[device_id] = d_userArgs;
return d_userArgs;
else
{
HipBlasLtUserArgsPtr user_args = make_hipblaslt_user_args_ptr(size, host);
hipblaslt_ext::UserArguments* raw_ptr = user_args.get();
user_args_cache[size] = std::move(user_args);
return raw_ptr;
}
}
private:
std::unordered_map<int, hipblaslt_ext::UserArguments*>
d_userArgs_map_; // Map from device_id to hipblasHandle
std::mutex mutex_;
};
// Define a static userArgs manager
static userArgsManager UAManager;
static d_userArgsManager d_UAManager;
void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
......@@ -1438,10 +1397,8 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Check compute_stream_offset valid.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
int device_id;
hipGetDevice(&device_id);
hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
......
......@@ -217,7 +217,13 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
// Choose between runtime-compiled or statically-compiled kernel
const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel
//TODO:Using RTC may cause kernel crashes. Therefore, set use_rtc to true to avoid using RTC and resolve the kernel crash issue.
#ifdef USE_ROCM
const bool use_rtc = false;
#else
const bool use_rtc = true;
#endif
if (aligned && rtc::is_enabled() && use_rtc) { // Runtime-compiled tuned kernel
// Pick kernel config
std::vector<KernelConfig> kernel_configs;
kernel_configs.reserve(16);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment