"tests/cpp/operator/test_multi_padding.cu" did not exist on "7f2703304dd3f90b282ea323d10f9f59b8d859fb"
Commit c37084b9 authored by yuguo's avatar yuguo
Browse files

[DCU] support NVTE_MOE_BATCHCOUNT

parent c686efc1
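For context, the new environment variable is read in two places: the C++ backend (via the getIntEnv helper added below, default 2, minimum 1) and the Python BatchedLinear module (via os.getenv, default "2"). A minimal usage sketch; the need to set it before module construction follows from how the diff reads it, and the snippet itself is illustrative rather than part of the commit:

```python
import os

# Set before the process builds any BatchedLinear modules, since the variable is
# read directly by both the Python module and the C++ multi-stream batched-GEMM path.
os.environ.setdefault("NVTE_MOE_BATCHCOUNT", "2")   # "2" is the default used by this commit

batch_num = max(int(os.getenv("NVTE_MOE_BATCHCOUNT", "2")), 1)   # the C++ getIntEnv clamps values below 1
print(f"each batched GEMM call will use batch_count={batch_num}")
```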
......@@ -731,6 +731,21 @@ void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHan
#endif
#ifdef __HIP_PLATFORM_AMD__
static inline int getIntEnv(const char *name, int defval, int minval)
{
int val = defval;
const char* env = std::getenv(name);
if (env != nullptr && env[0] != '\0')
{
val = atoi(env);
if (val < minval)
{
val = minval;
}
}
return val;
}
void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
......@@ -739,18 +754,17 @@ void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_batchgemm);
using namespace transformer_engine;
assert(num_gemms % num_batchgemm_streams == 0);
static int batch_count = num_gemms / num_batchgemm_streams;
int batch_count = getIntEnv("NVTE_MOE_BATCHCOUNT", 2, 1);
// Inits streams and events (once, globally)
std::call_once(init_flag_batchgemm, init_streams_and_events_batchgemm);
int num_stream_used = num_batchgemm_streams;
int num_stream_used = std::min(num_batchgemm_streams, num_gemms);
// wait for current stream to finish
NVTE_CHECK_CUDA(cudaEventRecord(cublas_event_batchgemm[0], stream));
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams_batchgemm[s], cublas_event_batchgemm[0]));
}
for (int i = 0; i < num_stream_used; i++) {
for (int i = 0; i < num_gemms; i++) {
nvte_cublas_batchgemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
workspace[i % num_batchgemm_streams], accumulate, use_split_accumulator, math_sm_count,
batch_count, compute_streams_batchgemm[i % num_batchgemm_streams]);
......
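Read together, the hunk above replaces the per-stream static batch count with the env-derived value and now loops over all num_gemms GEMMs, assigning each one round-robin to a compute stream and workspace. A rough Python model of that scheduling (not the real HIP/CUDA API; the stream-pool size and the GEMM count are made up for illustration):

```python
import os

num_batchgemm_streams = 4                                   # assumed size of the stream pool
num_gemms = 8                                               # hypothetical number of grouped GEMMs
batch_count = max(int(os.getenv("NVTE_MOE_BATCHCOUNT", "2")), 1)

# Only min(num_streams, num_gemms) streams need to wait on the incoming stream's event.
num_streams_used = min(num_batchgemm_streams, num_gemms)
print(f"{num_streams_used} streams wait on the main stream")

for i in range(num_gemms):
    slot = i % num_batchgemm_streams                        # round-robin stream/workspace index
    print(f"gemm {i}: stream/workspace {slot}, batch_count {batch_count}")
```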
......@@ -4,7 +4,7 @@
"""BatchedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
import os
import torch
import transformer_engine_torch as tex
......@@ -79,6 +79,8 @@ class _BatchedLinear(torch.autograd.Function):
*weights_and_biases,
) -> torch.Tensor:
batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
# pylint: disable=missing-function-docstring
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
......@@ -158,8 +160,9 @@ class _BatchedLinear(torch.autograd.Function):
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
assert weights_fp8[0].size(0) % batch_num == 0, "weights_fp8[0].size(0) should be a multiple of batch_num."
out = torch.empty(
[sum(m_splits), weights_fp8[0].size(0)],
[sum(m_splits), weights_fp8[0].size(0) // batch_num],
dtype=activation_dtype,
device=device,
)
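The narrower output buffer follows from the new weight layout: each stacked FP8 weight holds batch_num experts along dim 0, so dividing size(0) by batch_num recovers the per-expert out_features used as the per-token output width. A small numeric check with purely hypothetical sizes:

```python
batch_num = 2                                       # NVTE_MOE_BATCHCOUNT
out_features = 16                                   # hypothetical per-expert output width
m_splits = [8, 8]                                   # hypothetical row splits for each batched GEMM

weight_dim0 = out_features * batch_num              # size(0) of a stacked weight parameter
assert weight_dim0 % batch_num == 0                 # mirrors the assertion above
out_shape = (sum(m_splits), weight_dim0 // batch_num)
print(out_shape)                                    # (16, 16): width is the per-expert out_features
```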
......@@ -448,7 +451,9 @@ class BatchedLinear(TransformerEngineBaseModule):
super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.num_gemms = num_gemms
self.batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
assert num_gemms % self.batch_num == 0, "Number of GEMMs should be a multiple of batch_num."
self.num_gemms = num_gemms // self.batch_num
self.in_features = in_features
self.out_features = out_features
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
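Net effect of the constructor change: the module now registers num_gemms // batch_num weight parameters, each of which packs batch_num experts, and the FP8 metadata offsets below scale accordingly. A toy check of that bookkeeping (the expert count is hypothetical):

```python
import os

num_gemms = 8                                              # hypothetical expert count from the caller
batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
assert num_gemms % batch_num == 0, "Number of GEMMs should be a multiple of batch_num."

num_params = num_gemms // batch_num                        # weight0 .. weight{num_params-1} get registered
offsets = {"input": 0, "weight": num_params, "output": 2 * num_params, "grad_output": 0}
print(num_params, offsets)                                 # 4 parameter groups with the default batch count of 2
```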
......@@ -464,7 +469,7 @@ class BatchedLinear(TransformerEngineBaseModule):
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}
self._offsets = {"input": 0, "weight": self.num_gemms, "output": 2 * self.num_gemms, "grad_output": 0}
if tp_group is None:
self.tp_size = tp_size
......@@ -483,17 +488,17 @@ class BatchedLinear(TransformerEngineBaseModule):
if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size)
elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size)
self.in_features = divide(self.in_features * self.batch_num, self.tp_size)
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
# In batched-GEMM mode, each registered weight packs batch_num experts so one BLAS batched GEMM launch covers them
for i in range(self.num_gemms):
# Construct weight parameter
self.register_parameter(
f"weight{i}",
torch.nn.Parameter(
torch.empty(
self.out_features,
self.out_features * self.batch_num,
self.in_features,
device=device,
dtype=params_dtype,
......@@ -548,15 +553,15 @@ class BatchedLinear(TransformerEngineBaseModule):
# Set parallelism attributes for linear biases
if self.use_bias:
for i in range(self.num_gemms):
for bias in self.bias_names:
if self.parallel_mode == "row":
setattr(
getattr(self, f"bias{i}"),
getattr(self, bias),
"sequence_parallel",
self.sequence_parallel,
)
elif self.parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1)
set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)
@no_torch_dynamo()
def forward(
......@@ -591,7 +596,8 @@ class BatchedLinear(TransformerEngineBaseModule):
assert not isinstance(
inp, Float8Tensor
), "BatchedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
m_splits_batch_gemm = [x * self.batch_num for x in m_splits[0:int(self.num_gemms)]]
assert len(m_splits_batch_gemm) == self.num_gemms, "Number of splits should match number of GEMMs."
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
......@@ -641,7 +647,7 @@ class BatchedLinear(TransformerEngineBaseModule):
args = [None]
args += (
inp,
m_splits,
m_splits_batch_gemm,
self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch,
self.fp8,
......@@ -668,7 +674,7 @@ class BatchedLinear(TransformerEngineBaseModule):
[
o + cast_if_needed(b, self.activation_dtype)
for o, b in zip(
torch.split(out.view(-1, self.out_features), m_splits), bias_tensors
torch.split(out.view(-1, self.out_features), m_splits_batch_gemm), bias_tensors
)
]
).view(out_shape)
......
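Finally, in forward the caller's m_splits are reduced to the first self.num_gemms entries and scaled by batch_num, so each batched GEMM (and the later torch.split for the bias add) spans the rows of all the experts packed into that call; this reading assumes the grouped experts share the same split size. A toy illustration with made-up, uniform splits:

```python
batch_num = 2
num_gemms = 4                                        # self.num_gemms after division by batch_num
m_splits = [4] * (num_gemms * batch_num)             # hypothetical, uniform per-expert row counts

m_splits_batch_gemm = [x * batch_num for x in m_splits[:num_gemms]]
assert len(m_splits_batch_gemm) == num_gemms, "Number of splits should match number of GEMMs."
print(m_splits_batch_gemm)                           # [8, 8, 8, 8]: each entry spans batch_num experts' rows
```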