Commit c37084b9 authored by yuguo's avatar yuguo
Browse files

[DCU] support NVTE_MOE_BATCHCOUNT

parent c686efc1
...@@ -731,6 +731,21 @@ void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHan ...@@ -731,6 +731,21 @@ void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHan
#endif #endif
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
// Read an integer configuration value from the environment variable `name`.
// Returns `defval` when the variable is unset, empty, or contains no parsable
// digits; otherwise returns the parsed value clamped to at least `minval`.
static inline int getIntEnv(const char *name, int defval, int minval)
{
    const char* env = std::getenv(name);
    if (env == nullptr || env[0] == '\0')
    {
        return defval;
    }
    // strtol (unlike atoi) reports parse failures via the end pointer and
    // saturates on out-of-range input instead of invoking undefined behavior.
    char *end = nullptr;
    long parsed = std::strtol(env, &end, 10);
    if (end == env)
    {
        return defval;  // no digits consumed -> keep the default
    }
    if (parsed > INT_MAX) parsed = INT_MAX;
    if (parsed < INT_MIN) parsed = INT_MIN;
    int val = static_cast<int>(parsed);
    return (val < minval) ? minval : val;
}
void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out, const NVTETensor *bias, NVTETensor *pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad, const int num_gemms, bool transa, bool transb, bool grad,
...@@ -739,18 +754,17 @@ void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B ...@@ -739,18 +754,17 @@ void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm); NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
using namespace transformer_engine; using namespace transformer_engine;
assert(num_gemms % num_batchgemm_streams == 0); int batch_count = getIntEnv("NVTE_MOE_BATCHCOUNT", 2, 1);
static int batch_count = num_gemms / num_batchgemm_streams;
// Inits streams and events (once, globally) // Inits streams and events (once, globally)
std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm); std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);
int num_stream_used = num_batchgemm_streams; int num_stream_used = std::min(num_batchgemm_streams, num_gemms);
// wait for current stream to finish // wait for current stream to finish
NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[0], stream)); NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[0], stream));
for (int s = 0; s < num_stream_used; s++) { for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams_batchgemm[s], cublas_event_batchgemm[0])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams_batchgemm[s], cublas_event_batchgemm[0]));
} }
for (int i = 0; i < num_stream_used; i++) { for (int i = 0; i < num_gemms; i++) {
nvte_cublas_batchgemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, nvte_cublas_batchgemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_batchgemm_streams], accumulate, use_split_accumulator, math_sm_count, workspace[i % num_batchgemm_streams], accumulate, use_split_accumulator, math_sm_count,
batch_count, compute_streams_batchgemm[i % num_batchgemm_streams]); batch_count, compute_streams_batchgemm[i % num_batchgemm_streams]);
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""BatchedLinear API""" """BatchedLinear API"""
from typing import Union, Optional, Callable, Tuple, List from typing import Union, Optional, Callable, Tuple, List
import os
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
...@@ -79,6 +79,8 @@ class _BatchedLinear(torch.autograd.Function): ...@@ -79,6 +79,8 @@ class _BatchedLinear(torch.autograd.Function):
*weights_and_biases, *weights_and_biases,
) -> torch.Tensor: ) -> torch.Tensor:
batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
num_gemms = len(m_splits) num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms] weights = weights_and_biases[:num_gemms]
...@@ -158,8 +160,9 @@ class _BatchedLinear(torch.autograd.Function): ...@@ -158,8 +160,9 @@ class _BatchedLinear(torch.autograd.Function):
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
assert weights_fp8[0].size(0) % batch_num == 0, "weights_fp8[0].size(0) should be batch_num multiply."
out = torch.empty( out = torch.empty(
[sum(m_splits), weights_fp8[0].size(0)], [sum(m_splits), weights_fp8[0].size(0) // batch_num],
dtype=activation_dtype, dtype=activation_dtype,
device=device, device=device,
) )
...@@ -448,7 +451,9 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -448,7 +451,9 @@ class BatchedLinear(TransformerEngineBaseModule):
super().__init__() super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.num_gemms = num_gemms self.batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
assert num_gemms % self.batch_num == 0, "Number of GEMMs should be batch_num multiply."
self.num_gemms = num_gemms // self.batch_num
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
...@@ -464,7 +469,7 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -464,7 +469,7 @@ class BatchedLinear(TransformerEngineBaseModule):
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name self.rng_tracker_name = rng_tracker_name
self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0} self._offsets = {"input": 0, "weight": self.num_gemms, "output": 2 * self.num_gemms, "grad_output": 0}
if tp_group is None: if tp_group is None:
self.tp_size = tp_size self.tp_size = tp_size
...@@ -483,17 +488,17 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -483,17 +488,17 @@ class BatchedLinear(TransformerEngineBaseModule):
if self.parallel_mode == "column": if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size) self.out_features = divide(self.out_features, self.tp_size)
elif self.parallel_mode == "row": elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size) self.in_features = divide(self.in_features * self.batch_num, self.tp_size)
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
# In batchgemm, we use batch=batch_num to launch blas batchgemm
for i in range(self.num_gemms): for i in range(self.num_gemms):
# Construct weight parameter # Construct weight parameter
self.register_parameter( self.register_parameter(
f"weight{i}", f"weight{i}",
torch.nn.Parameter( torch.nn.Parameter(
torch.empty( torch.empty(
self.out_features, self.out_features * self.batch_num,
self.in_features, self.in_features,
device=device, device=device,
dtype=params_dtype, dtype=params_dtype,
...@@ -548,15 +553,15 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -548,15 +553,15 @@ class BatchedLinear(TransformerEngineBaseModule):
# Set parallelism attributes for linear biases # Set parallelism attributes for linear biases
if self.use_bias: if self.use_bias:
for i in range(self.num_gemms): for bias in self.bias_names:
if self.parallel_mode == "row": if self.parallel_mode == "row":
setattr( setattr(
getattr(self, f"bias{i}"), getattr(self, bias),
"sequence_parallel", "sequence_parallel",
self.sequence_parallel, self.sequence_parallel,
) )
elif self.parallel_mode == "column": elif self.parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1) set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)
@no_torch_dynamo() @no_torch_dynamo()
def forward( def forward(
...@@ -591,7 +596,8 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -591,7 +596,8 @@ class BatchedLinear(TransformerEngineBaseModule):
assert not isinstance( assert not isinstance(
inp, Float8Tensor inp, Float8Tensor
), "BatchedLinear doesn't support input tensor in FP8." ), "BatchedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." m_splits_batch_gemm = [x * self.batch_num for x in m_splits[0:int(self.num_gemms)]]
assert len(m_splits_batch_gemm) == self.num_gemms, "Number of splits should match number of GEMMs."
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None: if skip_fp8_weight_update is not None:
...@@ -641,7 +647,7 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -641,7 +647,7 @@ class BatchedLinear(TransformerEngineBaseModule):
args = [None] args = [None]
args += ( args += (
inp, inp,
m_splits, m_splits_batch_gemm,
self.apply_bias and not self.gemm_bias_unfused_add, self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch, is_first_microbatch,
self.fp8, self.fp8,
...@@ -668,7 +674,7 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -668,7 +674,7 @@ class BatchedLinear(TransformerEngineBaseModule):
[ [
o + cast_if_needed(b, self.activation_dtype) o + cast_if_needed(b, self.activation_dtype)
for o, b in zip( for o, b in zip(
torch.split(out.view(-1, self.out_features), m_splits), bias_tensors torch.split(out.view(-1, self.out_features), m_splits_batch_gemm), bias_tensors
) )
] ]
).view(out_shape) ).view(out_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment