Commit 9e6e1871 authored by yuguo

Merge branch 'develop_v2.3' into 'main'

Develop v2.3

See merge request dcutoolkit/deeplearing/TransformerEngine!9
parents 9815d228 460b006c
......@@ -496,7 +496,7 @@ def _train(opts):
if opts.benchmark:
# Warmup to not profile CPU overhead
for _ in range(20):
for _ in range(opts.benchmark_iter):
if opts.use_cuda_graphs:
test_graph.replay()
else:
......
......@@ -171,7 +171,7 @@ def reset_global_fp8_state():
FP8GlobalStateManager.reset()
def _test_batched_linear_accuracy(
block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
):
reset_rng_states()
if fp8:
......@@ -202,9 +202,31 @@ def _test_batched_linear_accuracy(
)
loss = out.sum()
loss.backward()
if delay_wgrad_compute:
if isinstance(block, BatchedLinear):
block.backward_dw()
else:
for i in range(num_gemms):
block[i].backward_dw()
torch.cuda.synchronize()
outputs = [out, inp_hidden_states.grad]
for p in block.parameters():
if p.requires_grad:
if isinstance(block, BatchedLinear):
if getattr(p, "main_grad", None) is not None:
for j in range(batch_num):
outputs.append(p.main_grad[p.main_grad.shape[0] // batch_num * j : p.main_grad.shape[0] // batch_num * (j + 1)])
assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True
else:
for j in range(batch_num):
outputs.append(p.grad[p.grad.shape[0] // batch_num * j : p.grad.shape[0] // batch_num * (j + 1)])
else:
if getattr(p, "main_grad", None) is not None:
outputs.append(p.main_grad)
assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True
else:
outputs.append(p.grad)
return outputs
@pytest.mark.parametrize("dtype", param_types)
......@@ -215,6 +237,7 @@ def _test_batched_linear_accuracy(
@pytest.mark.parametrize("recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_batched_linear_accuracy(
dtype,
num_gemms,
......@@ -224,6 +247,7 @@ def test_batched_linear_accuracy(
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
delay_wgrad_compute,
parallel_mode=None,
):
batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
......@@ -250,6 +274,7 @@ def test_batched_linear_accuracy(
parallel_mode=parallel_mode,
device="cuda",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
delay_wgrad_compute=delay_wgrad_compute,
).eval()
sequential_linear = torch.nn.ModuleList(
[
......@@ -281,10 +306,10 @@ def test_batched_linear_accuracy(
sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[weight_i.main_grad.shape[0] // batch_num * j : weight_i.main_grad.shape[0] // batch_num * (j + 1)].clone()
outputs_ref = _test_batched_linear_accuracy(
sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
)
outputs = _test_batched_linear_accuracy(
batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
)
# Should be a bit-wise match
......@@ -292,4 +317,4 @@ def test_batched_linear_accuracy(
torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)
if __name__ == "__main__":
test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True)
test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True, True)
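For context on how the new delay_wgrad_compute flag is meant to be driven outside this test, a minimal usage sketch follows; the import path, constructor arguments, call pattern, and shapes are assumptions inferred from the test above, not a verified public API:

import torch
from transformer_engine.pytorch import BatchedLinear  # assumed export location

layer = BatchedLinear(  # hypothetical sizes, for illustration only
    in_features=256,
    out_features=256,
    num_gemms=2,
    device="cuda",
    delay_wgrad_compute=True,
)
inp = torch.randn(8, 256, device="cuda", requires_grad=True)
out = layer(inp)
out.sum().backward()    # dgrad runs here; the wgrad GEMMs are queued, not executed
layer.backward_dw()     # flush the queued weight-gradient GEMMs
torch.cuda.synchronize()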
......@@ -68,10 +68,12 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
_gemm_priority = gemm_priority;
_comm_priority = comm_priority;
}
static cudaStream_t compute_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority));
_stream_compute.push_back(std::move(stream));
if (compute_streams[i] == nullptr) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, _gemm_priority));
}
_stream_compute.push_back(compute_streams[i]);
}
_num_splits = num_splits;
......@@ -225,6 +227,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false,
atomic_gemm) {
_ub_stream_nums = num_max_streams;
_rs_overlap_first_gemm = rs_overlap_first_gemm;
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
......@@ -238,8 +241,12 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
static cudaStream_t comm_stream;
if (comm_stream == nullptr) {
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority));
cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, _comm_priority));
}
_stream_comm = comm_stream;
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
}
......@@ -307,7 +314,6 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
} // CommOverlapBase::bulk_overlap
......@@ -444,9 +450,15 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
} else {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0], 1, 0, 0);
}
for (int i = 1; i < _num_splits; i++) {
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
......@@ -454,10 +466,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
}
NVTE_CHECK_CUDA(
cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
......@@ -502,10 +521,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
}
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
......@@ -536,7 +562,6 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
} // CommOverlapBase::split_overlap_rs
/***************************************************************************************************
......@@ -555,6 +580,8 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm) {
_ub_stream_nums = num_max_streams;
_is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS;
_aggregate = aggregate;
......@@ -603,13 +630,19 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
}
static cudaStream_t send_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
static cudaStream_t recv_stream;
for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
_stream_send.push_back(std::move(stream));
if (send_streams[i] == nullptr) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&send_streams[i], cudaStreamNonBlocking, _comm_priority));
}
_stream_send.push_back(send_streams[i]);
}
if (recv_stream == nullptr) {
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority));
cudaStreamCreateWithPriority(&recv_stream, cudaStreamNonBlocking, _comm_priority));
}
_stream_recv = recv_stream;
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0));
}
......@@ -813,10 +846,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
}
if (i < num_steps - 1) {
// P2P communication
......@@ -857,10 +897,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
}
if (i < _tp_size - 1) {
// P2P communication
......@@ -891,7 +938,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
} // CommOverlapP2PBase::split_overlap_ag
/*
......@@ -1003,9 +1049,15 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
auto workspace_chunk =
get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[stream_id]);
} else {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0, stream_id);
}
if (i > 0) {
// P2P communication chunk
......@@ -1034,7 +1086,6 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
......
......@@ -15,11 +15,7 @@
#include "common/comm_gemm_overlap/userbuffers/userbuffers.h"
#ifdef __HIP_PLATFORM_AMD__
#define NVTE_COMM_OVERLAP_MAX_STREAMS 1
#else
#define NVTE_COMM_OVERLAP_MAX_STREAMS 3
#endif
namespace transformer_engine {
......@@ -141,6 +137,7 @@ class CommOverlapCore {
class CommOverlapBase : public CommOverlapCore {
protected:
int _ub_stream_nums;
int _rs_kernel_type;
bool _rs_overlap_first_gemm;
cudaStream_t _stream_comm;
......@@ -206,6 +203,7 @@ class CommOverlapBase : public CommOverlapCore {
class CommOverlapP2PBase : public CommOverlapCore {
protected:
int _ub_stream_nums;
bool _is_reduce_scatter{false};
bool _use_multiatomic_ag{false};
bool _aggregate;
......
......@@ -52,10 +52,11 @@ _2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_multi_stream_cublas_workspace = []
_dummy_wgrads = {}
multi_stream_cublas_batchgemm_workspace = []
_multi_stream_cublas_batchgemm_workspace = []
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
ub_stream_nums = int(os.getenv("NVTE_UB_STREAM_NUMS", "2"))
_NUM_MAX_UB_STREAMS = ub_stream_nums if IS_HIP_EXTENSION else 3
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
layers_atomic_ring_exchange = []
......
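The NVTE_UB_STREAM_NUMS knob added above is read at module import time, so it must be set in the environment before transformer_engine is imported; on the ROCm build it caps the number of userbuffer GEMM streams, and a value of 1 makes the C++ overlap code take the single-stream nvte_cublas_gemm branches guarded by `_ub_stream_nums == 1`. A minimal sketch (the described effect is an inference from this diff, not documented behaviour):

import os
os.environ["NVTE_UB_STREAM_NUMS"] = "1"  # must be set before the import below

import transformer_engine.pytorch as te  # ROCm build then caps UB GEMM streams at 1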
......@@ -6,7 +6,7 @@
import os
import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union, List
import functools
import torch
import transformer_engine_torch as tex
......@@ -18,6 +18,7 @@ from .base import (
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ._common import WeightGradStore
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
divide,
......@@ -82,6 +83,7 @@ class _BatchLinear(torch.autograd.Function):
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
wgrad_store: WeightGradStore,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
......@@ -183,6 +185,7 @@ class _BatchLinear(torch.autograd.Function):
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False
ctx.wgrad_store = wgrad_store
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
......@@ -246,26 +249,30 @@ class _BatchLinear(torch.autograd.Function):
torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
for w in weights
]
# WGRAD
_, grad_biases, _ = batchgemm(
inputmats,
grad_output_mats,
wgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_batchgemm_workspace(),
batched_gemm_wgrad = functools.partial(
batchgemm,
dtype=ctx.activation_dtype,
workspaces=get_multi_stream_cublas_batchgemm_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
)
# WGRAD
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([inputmats, grad_output_mats, wgrad_list], batched_gemm_wgrad)
else:
_, grad_biases_, _ = batched_gemm_wgrad(inputmats, grad_output_mats, wgrad_list)
for i in range(ctx.num_gemms):
if grad_biases[i] is None:
grad_biases[i] = grad_biases_[i]
del grad_biases_
# Deallocate input tensor
clear_tensor_data(*inputmats)
clear_tensor_data(*inputmats_t)
if not ctx.use_bias:
grad_biases = [None] * ctx.num_gemms
def handle_custom_ddp_from_mcore(w, wgrad):
if w.requires_grad:
if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
......@@ -293,6 +300,18 @@ class _BatchLinear(torch.autograd.Function):
wgrad_list = [
handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
]
else:
wgrad_list = [None] * ctx.num_gemms
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
wgrad_list = [None] * ctx.num_gemms
if not ctx.use_bias or (
ctx.wgrad_store is not None
and ctx.wgrad_store.delay_wgrad_compute()
and not ctx.fp8
):
grad_biases = [None] * ctx.num_gemms
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
......@@ -304,6 +323,7 @@ class _BatchLinear(torch.autograd.Function):
None, # is_first_microbatch
None, # fp8
None, # fp8_calibration
None, # wgrad_store
None, # fp8_meta
None, # fuse_wgrad_accumulation
None, # cpu_offloading
......@@ -381,6 +401,8 @@ class BatchedLinear(TransformerEngineBaseModule):
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
delay_wgrad_compute : bool, default = `False`
                     Whether to delay the weight gradient computation until `backward_dw()` is called explicitly.
"""
def __init__(
......@@ -403,6 +425,7 @@ class BatchedLinear(TransformerEngineBaseModule):
ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False,
ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False,
) -> None:
super().__init__()
......@@ -424,6 +447,8 @@ class BatchedLinear(TransformerEngineBaseModule):
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self.wgrad_store = WeightGradStore(delay_wgrad_compute)
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
_GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, self.num_gemms, 2 * self.num_gemms
......@@ -588,6 +613,7 @@ class BatchedLinear(TransformerEngineBaseModule):
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fp8_meta,
self.fuse_wgrad_accumulation,
CPUOffloadEnabled,
......@@ -617,3 +643,27 @@ class BatchedLinear(TransformerEngineBaseModule):
if self.return_bias:
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
return out
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
This method is called after the main backward pass to compute weight gradients.
"""
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
return
with torch.cuda.nvtx.range("_BatchedLinear_wgrad"):
(_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
wgrad_list = tensor_list[2]
if not self.fuse_wgrad_accumulation:
for i in range(self.num_gemms):
weight_param = getattr(self, f"weight{i}")
if weight_param.grad is None:
weight_param.grad = wgrad_list[i].to(weight_param.dtype)
if self.use_bias:
for i in range(self.num_gemms):
bias_param = getattr(self, f"bias{i}")
if bias_param.grad is None:
bias_param.grad = grad_biases_[i].to(bias_param.dtype)
del grad_biases_
del wgrad_list
del tensor_list
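backward_dw() pops work that _BatchLinear.backward queued via wgrad_store.put(). The store itself is imported from ._common; the sketch below is an illustrative reconstruction of the contract the code above relies on (put remembers the GEMM inputs plus a callable, pop runs the callable and returns its outputs together with the stored inputs), not the actual implementation:

from collections import deque

class WeightGradStore:
    """Illustrative sketch of the interface used by BatchedLinear.backward_dw (assumption)."""

    def __init__(self, delay_wgrad_compute=False):
        self._delay = delay_wgrad_compute
        self._queue = deque()

    def delay_wgrad_compute(self):
        # True when wgrad GEMMs should be queued instead of run inside backward().
        return self._delay

    def put(self, tensor_list, func):
        # Remember the GEMM inputs and the callable that will compute the wgrads.
        self._queue.append((tensor_list, func))

    def pop(self):
        # Run the deferred computation now; return its outputs and the stored inputs.
        tensor_list, func = self._queue.popleft()
        return func(*tensor_list), tensor_list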