Commit 9e6e1871 authored by yuguo

Merge branch 'develop_v2.3' into 'main'

Develop v2.3

See merge request dcutoolkit/deeplearing/TransformerEngine!9
parents 9815d228 460b006c
@@ -496,7 +496,7 @@ def _train(opts):
     if opts.benchmark:
         # Warmup to not profile CPU overhead
-        for _ in range(20):
+        for _ in range(opts.benchmark_iter):
             if opts.use_cuda_graphs:
                 test_graph.replay()
             else:
...
@@ -171,7 +171,7 @@ def reset_global_fp8_state():
     FP8GlobalStateManager.reset()


 def _test_batched_linear_accuracy(
-    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
 ):
     reset_rng_states()
     if fp8:
@@ -202,9 +202,31 @@ def _test_batched_linear_accuracy(
     )
     loss = out.sum()
     loss.backward()
+    if delay_wgrad_compute:
+        if isinstance(block, BatchedLinear):
+            block.backward_dw()
+        else:
+            for i in range(num_gemms):
+                block[i].backward_dw()
     torch.cuda.synchronize()

     outputs = [out, inp_hidden_states.grad]
+    for p in block.parameters():
+        if p.requires_grad:
+            if isinstance(block, BatchedLinear):
+                if getattr(p, "main_grad", None) is not None:
+                    for j in range(batch_num):
+                        outputs.append(p.main_grad[p.main_grad.shape[0] // batch_num * j : p.main_grad.shape[0] // batch_num * (j + 1)])
+                    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
+                else:
+                    for j in range(batch_num):
+                        outputs.append(p.grad[p.grad.shape[0] // batch_num * j : p.grad.shape[0] // batch_num * (j + 1)])
+            else:
+                if getattr(p, "main_grad", None) is not None:
+                    outputs.append(p.main_grad)
+                    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
+                else:
+                    outputs.append(p.grad)
     return outputs


 @pytest.mark.parametrize("dtype", param_types)
@@ -215,6 +237,7 @@ def _test_batched_linear_accuracy(
 @pytest.mark.parametrize("recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
+@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
 def test_batched_linear_accuracy(
     dtype,
     num_gemms,
@@ -224,6 +247,7 @@ def test_batched_linear_accuracy(
     recipe,
     fp8_model_params,
     fuse_wgrad_accumulation,
+    delay_wgrad_compute,
     parallel_mode=None,
 ):
     batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
@@ -250,6 +274,7 @@ def test_batched_linear_accuracy(
         parallel_mode=parallel_mode,
         device="cuda",
         fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+        delay_wgrad_compute=delay_wgrad_compute,
     ).eval()
     sequential_linear = torch.nn.ModuleList(
         [
@@ -281,10 +306,10 @@ def test_batched_linear_accuracy(
                 sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[weight_i.main_grad.shape[0] // batch_num * j : weight_i.main_grad.shape[0] // batch_num * (j + 1)].clone()

     outputs_ref = _test_batched_linear_accuracy(
-        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
     )
     outputs = _test_batched_linear_accuracy(
-        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
     )

     # Should be bit-wise match
@@ -292,4 +317,4 @@ def test_batched_linear_accuracy(
         torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)


 if __name__ == "__main__":
-    test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True)
+    test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True, True)
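The per-batch gradient checks added above slice the fused `main_grad` / `grad` buffer into `batch_num` equal row blocks. A small self-contained illustration of that slicing arithmetic (the sizes are made up for the example and are not taken from the test):

```python
import torch

# Illustrative stand-in for a fused per-batch gradient buffer:
# batch_num blocks stacked along dim 0.
batch_num = 2
main_grad = torch.arange(8.0).reshape(8, 1)

chunk = main_grad.shape[0] // batch_num  # rows per batch entry -> 4
slices = [main_grad[chunk * j : chunk * (j + 1)] for j in range(batch_num)]

assert [s.shape for s in slices] == [torch.Size([4, 1]), torch.Size([4, 1])]
assert torch.equal(torch.cat(slices), main_grad)  # the blocks tile the fused buffer exactly
```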
@@ -68,10 +68,12 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
     _gemm_priority = gemm_priority;
     _comm_priority = comm_priority;
   }
+  static cudaStream_t compute_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
   for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
-    cudaStream_t stream;
-    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority));
-    _stream_compute.push_back(std::move(stream));
+    if (compute_streams[i] == nullptr) {
+      NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, _gemm_priority));
+    }
+    _stream_compute.push_back(compute_streams[i]);
   }

   _num_splits = num_splits;
@@ -225,6 +227,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
                       allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
                       gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false,
                       atomic_gemm) {
+  _ub_stream_nums = num_max_streams;
   _rs_overlap_first_gemm = rs_overlap_first_gemm;
   _rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
   NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
@@ -238,8 +241,12 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
   if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
   _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);

-  NVTE_CHECK_CUDA(
-      cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority));
+  static cudaStream_t comm_stream;
+  if (comm_stream == nullptr) {
+    NVTE_CHECK_CUDA(
+        cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, _comm_priority));
+  }
+  _stream_comm = comm_stream;
   NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
 }
@@ -307,7 +314,6 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0]));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
-  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
 }  // CommOverlapBase::bulk_overlap
@@ -444,9 +450,15 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
     auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
     auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});

-    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
-                     pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                     use_split_accumulator, _math_sms, _stream_compute[0], 1, 0, 0);
+    if (_ub_stream_nums == 1) {
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                       use_split_accumulator, _math_sms, _stream_compute[0]);
+    } else {
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                       use_split_accumulator, _math_sms, _stream_compute[0], 1, 0, 0);
+    }

     for (int i = 1; i < _num_splits; i++) {
       input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
@@ -454,10 +466,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
       workspace_chunk = get_tensor_chunk(
           workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});

-      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
-                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
-                       accumulate, use_split_accumulator, _math_sms,
-                       _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+      if (_ub_stream_nums == 1) {
+        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+                         pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
+                         accumulate, use_split_accumulator, _math_sms,
+                         _stream_compute[i % _stream_compute.size()]);
+      } else {
+        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+                         pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
+                         accumulate, use_split_accumulator, _math_sms,
+                         _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+      }

       NVTE_CHECK_CUDA(
           cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
@@ -502,10 +521,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
       auto workspace_chunk = get_tensor_chunk(
           workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});

-      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
-                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
-                       accumulate, use_split_accumulator, _math_sms,
-                       _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+      if (_ub_stream_nums == 1) {
+        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+                         pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
+                         accumulate, use_split_accumulator, _math_sms,
+                         _stream_compute[i % _stream_compute.size()]);
+      } else {
+        nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+                         pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
+                         accumulate, use_split_accumulator, _math_sms,
+                         _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+      }

       NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
       NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
@@ -536,7 +562,6 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
   }

   NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
-  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
 }  // CommOverlapBase::split_overlap_rs

 /***************************************************************************************************
@@ -555,6 +580,8 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
                       allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
                       gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
                       atomic_gemm) {
+  _ub_stream_nums = num_max_streams;
+
   _is_p2p = true;
   _is_reduce_scatter = comm_type == CommOverlapType::RS;
   _aggregate = aggregate;
@@ -603,13 +630,19 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
     NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
   }

+  static cudaStream_t send_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
+  static cudaStream_t recv_stream;
   for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) {
-    cudaStream_t stream;
-    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
-    _stream_send.push_back(std::move(stream));
+    if (send_streams[i] == nullptr) {
+      NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&send_streams[i], cudaStreamNonBlocking, _comm_priority));
+    }
+    _stream_send.push_back(send_streams[i]);
   }
-  NVTE_CHECK_CUDA(
-      cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority));
+  if (recv_stream == nullptr) {
+    NVTE_CHECK_CUDA(
+        cudaStreamCreateWithPriority(&recv_stream, cudaStreamNonBlocking, _comm_priority));
+  }
+  _stream_recv = recv_stream;
   NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0));
   NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0));
 }
@@ -813,10 +846,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
       auto workspace_chunk = get_tensor_chunk(
           workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});

-      nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
-                       aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                       use_split_accumulator, _math_sms,
-                       _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+      if (_ub_stream_nums == 1) {
+        nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                         aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                         use_split_accumulator, _math_sms,
+                         _stream_compute[i % _stream_compute.size()]);
+      } else {
+        nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                         aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                         use_split_accumulator, _math_sms,
+                         _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+      }

       if (i < num_steps - 1) {
         // P2P communication
@@ -857,10 +897,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
       auto workspace_chunk = get_tensor_chunk(
           workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});

-      nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
-                       aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                       use_split_accumulator, _math_sms,
-                       _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+      if (_ub_stream_nums == 1) {
+        nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                         aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                         use_split_accumulator, _math_sms,
+                         _stream_compute[i % _stream_compute.size()]);
+      } else {
+        nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                         aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                         use_split_accumulator, _math_sms,
+                         _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+      }

       if (i < _tp_size - 1) {
         // P2P communication
@@ -891,7 +938,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
-  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
 }  // CommOverlapP2PBase::split_overlap_ag

 /*
@@ -1003,9 +1049,15 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
       auto workspace_chunk =
           get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});

-      nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
-                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                       use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0, stream_id);
+      if (_ub_stream_nums == 1) {
+        nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                         pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                         use_split_accumulator, _math_sms, _stream_compute[stream_id]);
+      } else {
+        nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                         pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                         use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0, stream_id);
+      }

       if (i > 0) {
         // P2P communication chunk
@@ -1034,7 +1086,6 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
   }

   NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
-  NVTE_CHECK_CUDA(cudaDeviceSynchronize());

   // Reduce GEMM output chunks
   char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
...
@@ -15,11 +15,7 @@
 #include "common/comm_gemm_overlap/userbuffers/userbuffers.h"

-#ifdef __HIP_PLATFORM_AMD__
-#define NVTE_COMM_OVERLAP_MAX_STREAMS 1
-#else
 #define NVTE_COMM_OVERLAP_MAX_STREAMS 3
-#endif

 namespace transformer_engine {
@@ -141,6 +137,7 @@ class CommOverlapCore {
 class CommOverlapBase : public CommOverlapCore {
  protected:
+  int _ub_stream_nums;
   int _rs_kernel_type;
   bool _rs_overlap_first_gemm;
   cudaStream_t _stream_comm;
@@ -206,6 +203,7 @@ class CommOverlapBase : public CommOverlapCore {
 class CommOverlapP2PBase : public CommOverlapCore {
  protected:
+  int _ub_stream_nums;
   bool _is_reduce_scatter{false};
   bool _use_multiatomic_ag{false};
   bool _aggregate;
...
@@ -52,10 +52,11 @@ _2X_ACC_DGRAD = True
 _2X_ACC_WGRAD = True

 _multi_stream_cublas_workspace = []
 _dummy_wgrads = {}
-multi_stream_cublas_batchgemm_workspace = []
+_multi_stream_cublas_batchgemm_workspace = []
 _cublas_workspace = None
 _ub_communicators = None
-_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
+ub_stream_nums = int(os.getenv("NVTE_UB_STREAM_NUMS", "2"))
+_NUM_MAX_UB_STREAMS = ub_stream_nums if IS_HIP_EXTENSION else 3
 _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
 layers_atomic_ring_exchange = []
...
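With this change the HIP build reads the number of userbuffers streams from the `NVTE_UB_STREAM_NUMS` environment variable (default `"2"`) instead of hard-coding 2, while CUDA builds keep 3. Since base.py evaluates the variable at module import time, it appears to need setting before `transformer_engine.pytorch` is imported; a value of 1 is presumably what the `_ub_stream_nums == 1` single-stream branches in the C++ hunks above key off. A minimal sketch under those assumptions:

```python
import os

# Assumption: must be set before transformer_engine.pytorch is imported,
# because base.py reads NVTE_UB_STREAM_NUMS at import time (default "2").
os.environ.setdefault("NVTE_UB_STREAM_NUMS", "1")

import transformer_engine.pytorch as te  # noqa: E402  (import after the env var is set)
```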
@@ -6,7 +6,7 @@
 import os
 import logging
 from typing import Any, Callable, Dict, Optional, Tuple, Union, List
+import functools

 import torch
 import transformer_engine_torch as tex
@@ -18,6 +18,7 @@ from .base import (
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
+from ._common import WeightGradStore
 from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
 from ..utils import (
     divide,
@@ -82,6 +83,7 @@ class _BatchLinear(torch.autograd.Function):
         is_first_microbatch: Union[bool, None],
         fp8: bool,
         fp8_calibration: bool,
+        wgrad_store: WeightGradStore,
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         cpu_offloading: bool,
@@ -183,6 +185,7 @@ class _BatchLinear(torch.autograd.Function):
         ctx.tp_size = tp_size
         ctx.requires_dgrad = inp.requires_grad
         ctx.reduce_and_update_bwd_fp8_tensors = False
+        ctx.wgrad_store = wgrad_store

         # [*, in_features] -> [*, out_features] except first dimension changes for SP
         return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@@ -246,53 +249,69 @@ class _BatchLinear(torch.autograd.Function):
                 torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
                 for w in weights
             ]
-            # WGRAD
-            _, grad_biases, _ = batchgemm(
-                inputmats,
-                grad_output_mats,
-                wgrad_list,
-                ctx.activation_dtype,
-                get_multi_stream_cublas_batchgemm_workspace(),
-                layout="NT",
-                grad=True,
-                use_bias=ctx.use_bias,
-                accumulate=accumulate_wgrad_into_param_main_grad,
-            )
-
-            # Deallocate input tensor
-            clear_tensor_data(*inputmats)
-            clear_tensor_data(*inputmats_t)
-
-            if not ctx.use_bias:
-                grad_biases = [None] * ctx.num_gemms
-
-            def handle_custom_ddp_from_mcore(w, wgrad):
-                if w.requires_grad:
-                    if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
-                        w.grad_added_to_main_grad = True
-                        if getattr(w, "zero_out_wgrad", False):
-                            wgrad = torch.zeros(
-                                w.main_grad.shape,
-                                dtype=w.dtype,
-                                device=torch.cuda.current_device(),
-                                requires_grad=False,
-                            )
-                        else:
-                            wgrad = torch.empty(
-                                w.main_grad.shape,
-                                dtype=w.dtype,
-                                device=torch.cuda.current_device(),
-                                requires_grad=False,
-                            )
-                    elif ctx.fuse_wgrad_accumulation:
-                        wgrad = None
-                else:
-                    wgrad = None
-                return wgrad
-
-            wgrad_list = [
-                handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
-            ]
+            batched_gemm_wgrad = functools.partial(
+                batchgemm,
+                dtype=ctx.activation_dtype,
+                workspaces=get_multi_stream_cublas_batchgemm_workspace(),
+                layout="NT",
+                grad=True,
+                use_bias=ctx.use_bias,
+                accumulate=accumulate_wgrad_into_param_main_grad,
+            )
+
+            # WGRAD
+            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+                ctx.wgrad_store.put([inputmats, grad_output_mats, wgrad_list], batched_gemm_wgrad)
+            else:
+                _, grad_biases_, _ = batched_gemm_wgrad(inputmats, grad_output_mats, wgrad_list)
+
+                for i in range(ctx.num_gemms):
+                    if grad_biases[i] is None:
+                        grad_biases[i] = grad_biases_[i]
+                del grad_biases_
+
+                # Deallocate input tensor
+                clear_tensor_data(*inputmats)
+                clear_tensor_data(*inputmats_t)
+
+            def handle_custom_ddp_from_mcore(w, wgrad):
+                if w.requires_grad:
+                    if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
+                        w.grad_added_to_main_grad = True
+                        if getattr(w, "zero_out_wgrad", False):
+                            wgrad = torch.zeros(
+                                w.main_grad.shape,
+                                dtype=w.dtype,
+                                device=torch.cuda.current_device(),
+                                requires_grad=False,
+                            )
+                        else:
+                            wgrad = torch.empty(
+                                w.main_grad.shape,
+                                dtype=w.dtype,
+                                device=torch.cuda.current_device(),
+                                requires_grad=False,
+                            )
+                    elif ctx.fuse_wgrad_accumulation:
+                        wgrad = None
+                else:
+                    wgrad = None
+                return wgrad
+
+            wgrad_list = [
+                handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
+            ]
+        else:
+            wgrad_list = [None] * ctx.num_gemms
+
+        if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+            wgrad_list = [None] * ctx.num_gemms
+
+        if not ctx.use_bias or (
+            ctx.wgrad_store is not None
+            and ctx.wgrad_store.delay_wgrad_compute()
+            and not ctx.fp8
+        ):
+            grad_biases = [None] * ctx.num_gemms

         if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
             FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
@@ -304,6 +323,7 @@ class _BatchLinear(torch.autograd.Function):
             None,  # is_first_microbatch
             None,  # fp8
             None,  # fp8_calibration
+            None,  # wgrad_store
             None,  # fp8_meta
             None,  # fuse_wgrad_accumulation
             None,  # cpu_offloading
@@ -381,6 +401,8 @@ class BatchedLinear(TransformerEngineBaseModule):
                           it controls the type used to allocate the initial parameters. Useful when
                           the model is trained with lower precision and the original FP32 parameters
                           would not fit in GPU memory.
+    delay_wgrad_compute : bool, default = `False`
+                         Whether to delay weight gradient computation
     """

     def __init__(
@@ -403,6 +425,7 @@ class BatchedLinear(TransformerEngineBaseModule):
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
         ub_name: Optional[str] = None,
+        delay_wgrad_compute: bool = False,
     ) -> None:
         super().__init__()
@@ -424,6 +447,8 @@ class BatchedLinear(TransformerEngineBaseModule):
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name

+        self.wgrad_store = WeightGradStore(delay_wgrad_compute)
+
         global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
         _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, self.num_gemms, 2 * self.num_gemms
@@ -588,6 +613,7 @@ class BatchedLinear(TransformerEngineBaseModule):
             is_first_microbatch,
             self.fp8,
             self.fp8_calibration,
+            self.wgrad_store,
             self.fp8_meta,
             self.fuse_wgrad_accumulation,
             CPUOffloadEnabled,
@@ -617,3 +643,27 @@ class BatchedLinear(TransformerEngineBaseModule):
         if self.return_bias:
             return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
         return out
+
+    def backward_dw(self):
+        """
+        Execute the delayed weight gradient computation.
+        This method is called after the main backward pass to compute weight gradients.
+        """
+        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
+            return
+        with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
+            (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
+            wgrad_list = tensor_list[2]
+            if not self.fuse_wgrad_accumulation:
+                for i in range(self.num_gemms):
+                    weight_param = getattr(self, f"weight{i}")
+                    if weight_param.grad is None:
+                        weight_param.grad = wgrad_list[i].to(weight_param.dtype)
+            if self.use_bias:
+                for i in range(self.num_gemms):
+                    bias_param = getattr(self, f"bias{i}")
+                    if bias_param.grad is None:
+                        bias_param.grad = grad_biases_[i].to(bias_param.dtype)
+            del grad_biases_
+            del wgrad_list
+            del tensor_list
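With these pieces in place, BatchedLinear follows the same delayed-wgrad contract as GroupedLinear: build the module with `delay_wgrad_compute=True`, run forward and backward as usual (the batched wgrad GEMM is queued in the module's `WeightGradStore`), then call `backward_dw()` to materialize the weight and bias gradients. A hedged usage sketch; `layer` and `inp` stand for an already-constructed BatchedLinear (created with `delay_wgrad_compute=True`) and a suitable CUDA input tensor, since the constructor's positional signature is not shown in this diff:

```python
import torch

def run_step(layer: torch.nn.Module, inp: torch.Tensor) -> torch.Tensor:
    # Assumes `layer` is a transformer_engine BatchedLinear created with delay_wgrad_compute=True.
    out = layer(inp)
    loss = out.sum()
    loss.backward()      # dgrad runs here; the wgrad GEMM is parked in layer.wgrad_store
    layer.backward_dw()  # pop the stored GEMM and fill weight/bias .grad
                         # (with fuse_wgrad_accumulation, accumulation goes into .main_grad instead)
    torch.cuda.synchronize()
    return loss
```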