"EDK2/git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "2936666a78aa6a9db8e1f489d6de6900c151cae3"
Unverified Commit 70117306 authored by Xin Yao, committed by GitHub

[Bugfix] Fixes for multi-stream cuBLAS (#1045)



* Fix workspaces and unfused bias in multi-stream cuBLAS

* Expose num_streams via pybind

* Fix C-compatibility

* Remove the `packaging` import in test_fused_attn.py

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 91a16a3f
@@ -1261,7 +1261,9 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("fp8", all_boolean)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
-def test_grouped_linear_accuracy(dtype, num_gemms, bs, model, fp8, fp8_model_params):
+def test_grouped_linear_accuracy(
+    dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None
+):
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -1276,6 +1278,7 @@ def test_grouped_linear_accuracy(dtype, num_gemms, bs, model, fp8, fp8_model_par
             4 * config.hidden_size,
             bias=True,
             params_dtype=dtype,
+            parallel_mode=parallel_mode,
             device="cuda",
         ).eval()
     sequential_linear = torch.nn.ModuleList(
@@ -1285,6 +1288,7 @@ def test_grouped_linear_accuracy(dtype, num_gemms, bs, model, fp8, fp8_model_par
                 4 * config.hidden_size,
                 bias=True,
                 params_dtype=dtype,
+                parallel_mode=parallel_mode,
                 device="cuda",
             ).eval()
             for _ in range(num_gemms)
@@ -1307,6 +1311,20 @@ def test_grouped_linear_accuracy(dtype, num_gemms, bs, model, fp8, fp8_model_par
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
 
 
+@pytest.mark.parametrize("parallel_mode", ["column", "row"])
+def test_grouped_linear_accuracy_parallel_mode(parallel_mode):
+    """Split the tests to reduce CI time"""
+    test_grouped_linear_accuracy(
+        dtype=torch.float32,
+        num_gemms=6,
+        bs=2,
+        model=list(model_configs.keys())[0],
+        fp8=True,
+        fp8_model_params=True,
+        parallel_mode=parallel_mode,
+    )
+
+
 def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
     reset_rng_states()
...
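For reference, a minimal standalone sketch of what the new parametrization exercises; the constructor layout mirrors the test above and is assumed here (hidden size is a placeholder, and fp8 setup is omitted):

# Hypothetical sketch of the path covered by the new
# test_grouped_linear_accuracy_parallel_mode test: the same module, but
# built in "column" or "row" parallel mode instead of the default None.
import torch
import transformer_engine.pytorch as te

hidden_size = 1024  # placeholder; the test derives this from model_configs

grouped_linear = te.GroupedLinear(
    6,                       # num_gemms, as in the new wrapper test
    hidden_size,
    4 * hidden_size,
    bias=True,
    params_dtype=torch.float32,
    parallel_mode="column",  # or "row"; None keeps the previous behavior
    device="cuda",
).eval()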
@@ -378,10 +378,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
                    math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream);
 }
 
-void nvte_multi_stream_cublas_gemm(std::vector<NVTETensor> A, std::vector<NVTETensor> B,
-                                   std::vector<NVTETensor> D, std::vector<NVTETensor> bias,
-                                   std::vector<NVTETensor> pre_gelu_out, bool transa, bool transb,
-                                   bool grad, std::vector<NVTETensor> workspace, bool accumulate,
+void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
+                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
+                                   const int num_gemms, bool transa, bool transb, bool grad,
+                                   NVTETensor *workspace, bool accumulate,
                                    bool use_split_accumulator, int math_sm_count,
                                    cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
@@ -389,14 +389,14 @@ void nvte_multi_stream_cublas_gemm(std::vector<NVTETensor> A, std::vector<NVTETe
   // Inits streams and events (once, globally)
   std::call_once(init_flag, init_streams_and_events);
 
-  int num_stream_used = std::min(num_streams, static_cast<int>(A.size()));
+  int num_stream_used = std::min(num_streams, num_gemms);
   // wait for current stream to finish
   NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream));
   for (int s = 0; s < num_stream_used; s++) {
     NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
   }
 
-  for (size_t i = 0; i < A.size(); i++) {
+  for (int i = 0; i < num_gemms; i++) {
     nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                      workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
                      compute_streams[i % num_streams]);
...
@@ -92,6 +92,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
  *  \param[in,out] D                      List of output matrices.
  *  \param[in]     bias                   List of bias tensors.
  *  \param[in,out] pre_gelu_out           List of output matrix before GELU activation.
+ *  \param[in]     num_gemms              Number of GEMMs to compute.
  *  \param[in]     transa                 Whether A matrix is transposed.
  *  \param[in]     transb                 Whether B matrix is transposed.
  *  \param[in]     grad                   Whether this operation is part of the
@@ -102,10 +103,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
  *  \param[in]     math_sm_count          Number of GPU SMs to use (default=0: use cuBLAS heuristics)
  *  \param[in]     stream                 CUDA stream to wait on.
  */
-void nvte_multi_stream_cublas_gemm(std::vector<NVTETensor> A, std::vector<NVTETensor> B,
-                                   std::vector<NVTETensor> D, std::vector<NVTETensor> bias,
-                                   std::vector<NVTETensor> pre_gelu_out, bool transa, bool transb,
-                                   bool grad, std::vector<NVTETensor> workspace, bool accumulate,
+void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
+                                   const NVTETensor* bias, NVTETensor* pre_gelu_out,
+                                   const int num_gemms, bool transa, bool transb, bool grad,
+                                   NVTETensor* workspace, bool accumulate,
                                    bool use_split_accumulator, int math_sm_count,
                                    cudaStream_t stream);
 
 #ifdef __cplusplus
...
@@ -134,12 +134,15 @@ void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int
     te_pre_gelu_out.emplace_back(make_tensor(
         pre_gelu_out[i].data_ptr(), gelu_shape,
         GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr));
-    te_workspace.emplace_back(make_tensor(workspace[i % num_streams].data_ptr(), {workspaceSize},
-                                          DType::kByte, nullptr, nullptr, nullptr));
+  }
+  for (size_t i = 0; i < workspace.size(); i++) {
+    te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte,
+                                          nullptr, nullptr, nullptr));
   }
 
   // For now, we only have multi-stream cublas backend.
-  nvte_multi_stream_cublas_gemm(te_A, te_B, te_D, te_bias, te_pre_gelu_out, transa, transb, grad,
-                                te_workspace, accumulate, use_split_accumulator, math_sm_count,
-                                at::cuda::getCurrentCUDAStream());
+  nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
+                                te_pre_gelu_out.data(), te_A.size(), transa, transb, grad,
+                                te_workspace.data(), accumulate, use_split_accumulator,
+                                math_sm_count, at::cuda::getCurrentCUDAStream());
 }
@@ -153,6 +153,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::call_guard<py::gil_scoped_release>());
   m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version",
         py::call_guard<py::gil_scoped_release>());
+  m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
 
   // Support THD format for Context Parallel
   m.def("thd_read_half_tensor", &thd_read_half_tensor,
...
@@ -48,7 +48,6 @@ _multi_stream_cublas_workspace = []
 _cublas_workspace = None
 _ub_communicators = None
 _NUM_MAX_UB_STREAMS = 3
-_NUM_MAX_CUBLAS_STREAMS = 4
 layers_atomic_ring_exchange = []
@@ -73,7 +72,7 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
     """Returns workspace for multi-stream cublas."""
     global _multi_stream_cublas_workspace
     if not _multi_stream_cublas_workspace:
-        for _ in range(_NUM_MAX_CUBLAS_STREAMS):
+        for _ in range(tex._num_cublas_streams):
             _multi_stream_cublas_workspace.append(
                 torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
             )
...
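Sizing the workspace list by `tex._num_cublas_streams` means each compute stream gets its own cuBLAS workspace, matching the `workspace[i % num_streams]` round-robin in `nvte_multi_stream_cublas_gemm`. A rough sketch of the resulting allocation pattern (the import alias and workspace size below are placeholders, not TE's exact code):

import torch
import transformer_engine_torch as tex  # alias assumed; the extension module exposing _num_cublas_streams

WORKSPACE_BYTES = 32 * 1024 * 1024  # placeholder; TE uses get_cublas_workspace_size_bytes()

# One workspace per cuBLAS compute stream, so GEMMs running concurrently on
# different streams never share a scratch buffer.
workspaces = [
    torch.empty(WORKSPACE_BYTES, dtype=torch.uint8, device="cuda")
    for _ in range(tex._num_cublas_streams)
]

# GEMM i is launched on stream i % num_streams and therefore reuses
# workspaces[i % num_streams], mirroring the loop in the C++ change above.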
@@ -829,7 +829,15 @@ class GroupedLinear(TransformerEngineBaseModule):
         out = linear_fn(*args)
 
         if self.gemm_bias_unfused_add:
-            out = [o + cast_if_needed(b, self.activation_dtype) for o, b in zip(out, bias_tensors)]
+            out_shape = out.shape
+            out = torch.cat(
+                [
+                    o + cast_if_needed(b, self.activation_dtype)
+                    for o, b in zip(
+                        torch.split(out.view(-1, self.out_features), m_splits), bias_tensors
+                    )
+                ]
+            ).view(out_shape)
 
         if self.return_bias:
             return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
...
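The unfused-bias fix reflects that `out` is a single contiguous tensor holding every GEMM's output rows stacked along the token dimension, not a list of per-GEMM tensors, so the bias has to be added per group after splitting by `m_splits`. A self-contained sketch of the same logic with plain torch ops (the function name and dtype handling are illustrative, not TE's API):

import torch

def add_unfused_bias(out, bias_tensors, m_splits, out_features, activation_dtype):
    # out: one tensor with out_features columns containing every GEMM's rows,
    # where group i owns m_splits[i] rows. Add each group's bias, restore shape.
    out_shape = out.shape
    chunks = torch.split(out.view(-1, out_features), m_splits)
    chunks = [o + b.to(activation_dtype) for o, b in zip(chunks, bias_tensors)]
    return torch.cat(chunks).view(out_shape)

# Example: three groups with 4, 2 and 6 rows respectively.
out = torch.randn(12, 8)
biases = [torch.randn(8) for _ in range(3)]
fixed = add_unfused_bias(out, biases, [4, 2, 6], out_features=8, activation_dtype=torch.float32)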