Unverified Commit ddc08039 authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub
Browse files

[transformer] Fuse grad accumulation with wgrad (#1297)



* fuse grad accumulation w/ weight grad
Co-authored-by: default avatarSangkug Lym <slym@nvidia.com>

* fp32 training path

* not using *args, **kwargs

* backward: moved the tensor dimension cnversion
Co-authored-by: default avatarSangkug Lym <slym@nvidia.com>

* move files to csrc/megatron

* fix fp32 path

* fix typo

* add  to  in order to select the correct custom extension

* fix typo

* comment on import guard

* update test: enable gradient_accumulation_fusion

* 86

* remove redundant call of `test_column_parallel_linear`
Co-authored-by: default avatarSangkug Lym <slym@nvidia.com>
parent e95c3b9c
...@@ -32,6 +32,26 @@ from apex.transformer.tensor_parallel.mappings import reduce_from_tensor_model_p ...@@ -32,6 +32,26 @@ from apex.transformer.tensor_parallel.mappings import reduce_from_tensor_model_p
from apex.transformer.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region from apex.transformer.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker
from apex.transformer.tensor_parallel.utils import VocabUtility from apex.transformer.tensor_parallel.utils import VocabUtility
from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)
_grad_accum_fusion_available = False
try:
import fused_weight_gradient_mlp_cuda
except ImportError:
# Basically, apex.transformer module users are expected to install APEX's
# `--cpp_ext` and `--cuda_ext`. The example installation command is as follows:
# `pip install --global-option="--cpp_ext" --global-option="--cuda_ext ."
# at the root of APEX repository.
_logger.warning(
"`fused_weight_gradient_mlp_cuda` module not found. "
"gradient accumulation fusion with weight gradient computation disabled."
)
else:
_grad_accum_fusion_available = True
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
...@@ -203,15 +223,14 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -203,15 +223,14 @@ class VocabParallelEmbedding(torch.nn.Module):
return output return output
class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function): class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
""" """Linear layer execution with asynchronous all-reduce and gradient accumulation fusion in backprop."""
Column-parallel linear layer execution with asynchronous all-reduce
execution in backprop.
"""
@staticmethod @staticmethod
def forward(ctx, input, weight, bias): def forward(ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce):
ctx.save_for_backward(input, weight) ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input, weight.t()) output = torch.matmul(input, weight.t())
if bias is not None: if bias is not None:
output = output + bias output = output + bias
...@@ -222,22 +241,90 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function): ...@@ -222,22 +241,90 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
input, weight = ctx.saved_tensors input, weight = ctx.saved_tensors
use_bias = ctx.use_bias use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight) grad_input = grad_output.matmul(weight)
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
input = input.view(input.shape[0] * input.shape[1], input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce # Asynchronous all-reduce
handle = torch.distributed.all_reduce( handle = torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=True)
grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have # Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated # all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1 _ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input) grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
handle.wait() handle.wait()
return grad_input, grad_weight, grad_bias return grad_input, grad_weight, grad_bias, None, None
def column_parallel_linear(input, weight, bias): def linear_with_grad_accumulation_and_async_allreduce(
args = _cast_if_autocast_enabled(input, weight, bias) input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
):
args = _cast_if_autocast_enabled(input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce)
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
return ColumnParallelLinearWithAsyncAllreduce.apply(*args) return LinearWithGradAccumulationAndAsyncAllreduce.apply(*args)
class LinearWithGradAccumulationAndAsyncAllreduceIn16Bit(torch.autograd.Function):
"""Linear layer execution with asynchronous all-reduce and gradient accumulation fusion in backprop."""
@staticmethod
def forward(ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight)
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
input = input.view(input.shape[0] * input.shape[1], input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input, grad_output, weight.main_grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None
def linear_with_grad_accumulation_and_async_allreduce_in16bit(
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
):
args = _cast_if_autocast_enabled(input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce)
with torch.cuda.amp.autocast(enabled=False):
return LinearWithGradAccumulationAndAsyncAllreduceIn16Bit.apply(*args)
class ColumnParallelLinear(torch.nn.Module): class ColumnParallelLinear(torch.nn.Module):
...@@ -262,6 +349,13 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -262,6 +349,13 @@ class ColumnParallelLinear(torch.nn.Module):
skip_bias_add: This was added to enable performance optimations where bias skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip can be fused with other elementwise operations. we skip
adding bias but instead return it. adding bias but instead return it.
Keyword Arguments:
no_async_tensor_model_parallel_allreduce:
params_dtype:
use_cpu_initialization:
gradient_accumulation_fusion:
accumulation_in_fp16:
""" """
def __init__( def __init__(
...@@ -278,6 +372,8 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -278,6 +372,8 @@ class ColumnParallelLinear(torch.nn.Module):
no_async_tensor_model_parallel_allreduce=False, no_async_tensor_model_parallel_allreduce=False,
params_dtype=torch.float32, params_dtype=torch.float32,
use_cpu_initialization=False, use_cpu_initialization=False,
gradient_accumulation_fusion=False,
accumulation_in_fp16: bool = False,
): ):
super(ColumnParallelLinear, self).__init__() super(ColumnParallelLinear, self).__init__()
...@@ -335,24 +431,22 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -335,24 +431,22 @@ class ColumnParallelLinear(torch.nn.Module):
self.async_tensor_model_parallel_allreduce = ( self.async_tensor_model_parallel_allreduce = (
not no_async_tensor_model_parallel_allreduce and not no_async_tensor_model_parallel_allreduce and
world_size > 1) world_size > 1)
self.gradient_accumulation_fusion = gradient_accumulation_fusion and _grad_accum_fusion_available
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce_in16bit if accumulation_in_fp16 else linear_with_grad_accumulation_and_async_allreduce
def forward(self, input_): def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
if self.async_tensor_model_parallel_allreduce: if not self.async_tensor_model_parallel_allreduce:
input_shape = input_.shape
input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
# Matrix multiply with asynchronous all-reduce execution
output_parallel = column_parallel_linear(input_, self.weight, bias)
output_parallel = output_parallel.view(
input_shape[0], input_shape[1], output_parallel.shape[1])
else:
# Set up backprop all-reduce. # Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_) input_parallel = copy_to_tensor_model_parallel_region(input_)
else:
input_parallel = input_
# Matrix multiply. # Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, bias) output_parallel = self._forward_impl(
input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
self.async_tensor_model_parallel_allreduce)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel) output = gather_from_tensor_model_parallel_region(output_parallel)
......
#include <torch/extension.h>
#include <cstdio>
#include <vector>
void wgrad_gemm_accum_fp32_cuda_stub(
at::Tensor &input_2d,
at::Tensor &d_output_2d,
at::Tensor &d_weight
);
void wgrad_gemm_accum_fp16_cuda_stub(
at::Tensor &input_2d,
at::Tensor &d_output_2d,
at::Tensor &d_weight
);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_cuda_stub, "wgrad gemm accum in fp32");
m.def("wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_cuda_stub, "wgrad gemm accum in fp16");
}
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "type_shim.h"
// BF16 inputs and BF16 accumulation
void gemmex_wrapper_fp16(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::BFloat16* A,
int lda,
at::BFloat16* B,
int ldb,
const float* beta,
at::BFloat16* C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16BF,
lda,
B,
CUDA_R_16BF,
ldb,
beta,
C,
CUDA_R_16BF,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP16 inputs and FP16 accumulation
void gemmex_wrapper_fp16(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float* beta,
at::Half* C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_16F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
template <typename T>
void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *d_weight, int in_dim, int hidden_dim, int out_dim) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta = 1.0;
gemmex_wrapper_fp16(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_dim,
out_dim,
hidden_dim,
&alpha,
input,
in_dim,
d_output,
out_dim,
&beta,
d_weight,
in_dim);
}
template void wgrad_gemm_accum_fp16_cuda<at::Half>(at::Half *input, at::Half *d_output, at::Half *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp16_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, at::BFloat16 *d_weight, int in_dim, int hidden_dim, int out_dim);
void wgrad_gemm_accum_fp16_cuda_stub(
at::Tensor &input,
at::Tensor &d_output,
at::Tensor &d_weight
) {
at::Tensor input_2d, d_output_2d;
// input tensor: collapse to the first dim
auto in_sizes = input.sizes();
if (input.dim() > 2) {
input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
} else {
input_2d = input;
}
// d_output tensor: collapse to the first dim
auto d_out_sizes = d_output.sizes();
if (d_output.dim() > 2) {
d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
} else {
d_output_2d = d_output;
}
const int hidden_dim = input_2d.size(0);
const int in_dim = input_2d.size(1);
const int out_dim = d_weight.size(0);
DISPATCH_HALF_AND_BFLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp16",
wgrad_gemm_accum_fp16_cuda<scalar_t>(
input_2d.data_ptr<scalar_t>(),
d_output_2d.data_ptr<scalar_t>(),
d_weight.data_ptr<scalar_t>(),
in_dim,
hidden_dim,
out_dim);
);
}
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "type_shim.h"
// BF16 Tensor core wrapper around cublas GEMMEx
void gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::BFloat16* A,
int lda,
at::BFloat16* B,
int ldb,
const float* beta,
float* C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16BF,
lda,
B,
CUDA_R_16BF,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP16 Tensor core wrapper around cublas GEMMEx
void gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float* beta,
float* C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP32 wrapper around cublas GEMMEx
void gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha,
float *A,
int lda,
float *B,
int ldb,
const float *beta,
float *C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_32F,
lda,
B,
CUDA_R_32F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
template <typename T>
void wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta = 1.0;
gemmex_wrapper(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_dim,
out_dim,
hidden_dim,
&alpha,
input,
in_dim,
d_output,
out_dim,
&beta,
d_weight,
in_dim);
}
template void wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
void wgrad_gemm_accum_fp32_cuda_stub(
at::Tensor &input,
at::Tensor &d_output,
at::Tensor &d_weight
) {
at::Tensor input_2d, d_output_2d;
// input tensor: collapse to the first dim
auto in_sizes = input.sizes();
if (input.dim() > 2) {
input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
} else {
input_2d = input;
}
// d_output tensor: collapse to the first dim
auto d_out_sizes = d_output.sizes();
if (d_output.dim() > 2) {
d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
} else {
d_output_2d = d_output;
}
const int hidden_dim = input_2d.size(0);
const int in_dim = input_2d.size(1);
const int out_dim = d_weight.size(0);
DISPATCH_FLOAT_HALF_AND_BFLOAT(input_2d.scalar_type(), 0, "wgrad_gemm_accum_fp32",
wgrad_gemm_accum_fp32_cuda<scalar_t_0>(
input_2d.data_ptr<scalar_t_0>(),
d_output_2d.data_ptr<scalar_t_0>(),
d_weight.data_ptr<float>(),
in_dim,
hidden_dim,
out_dim);
);
}
...@@ -297,6 +297,44 @@ if "--cuda_ext" in sys.argv: ...@@ -297,6 +297,44 @@ if "--cuda_ext" in sys.argv:
) )
) )
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if int(bare_metal_minor) > 0:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_86,code=sm_86")
ext_modules.append(
CUDAExtension(
name="fused_weight_gradient_mlp_cuda",
include_dirs=[os.path.join(this_dir, "csrc")],
sources=[
"csrc/megatron/fused_weight_gradient_dense.cpp",
"csrc/megatron/fused_weight_gradient_dense_cuda.cu",
"csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu",
],
extra_compile_args={
"cxx": ["-O3"] + version_dependent_macros,
"nvcc": append_nvcc_threads(
[
"-O3",
"-gencode",
"arch=compute_70,code=sm_70",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
]
+ version_dependent_macros
+ cc_flag
),
},
)
)
if "--permutation_search" in sys.argv: if "--permutation_search" in sys.argv:
sys.argv.remove("--permutation_search") sys.argv.remove("--permutation_search")
...@@ -499,10 +537,11 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -499,10 +537,11 @@ if "--fast_multihead_attn" in sys.argv:
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = [] cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11: if int(bare_metal_major) >= 11:
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80") cc_flag.append("arch=compute_80,code=sm_80")
if int(bare_metal_minor) > 0:
cc_flag.append("-gencode") cc_flag.append("-gencode")
cc_flag.append("arch=compute_86,code=sm_86") cc_flag.append("arch=compute_86,code=sm_86")
......
...@@ -220,52 +220,65 @@ def test_column_parallel_linear(tensor_model_parallel_size): ...@@ -220,52 +220,65 @@ def test_column_parallel_linear(tensor_model_parallel_size):
output_size_coeff = 17 output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7 batch_size = 7
hidden_size = 9
# Network # Network
identity_layer = IdentityLayer2D(batch_size, input_size).cuda() gradient_accumulation_fusion = True
identity_layer = IdentityLayer3D(batch_size, hidden_size, input_size).cuda()
linear_layer = layers.ColumnParallelLinear( linear_layer = layers.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True, input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype, params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization, use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
gradient_accumulation_fusion=gradient_accumulation_fusion,
).cuda() ).cuda()
loss_weight = torch.randn([batch_size, output_size]).cuda() with torch.no_grad():
linear_layer.weight.main_grad = torch.randn_like(linear_layer.weight)
loss_weight = torch.randn([batch_size, hidden_size, output_size]).cuda()
# Forward # Forward
input_ = identity_layer() input_ = identity_layer()
output, _ = linear_layer(input_) output, _ = linear_layer(input_)
assert list(output.shape) == [batch_size, hidden_size, output_size]
loss = torch.mul(output, loss_weight).sum() loss = torch.mul(output, loss_weight).sum()
# Backward # Backward
loss.backward() loss.backward()
# TODO (mkozuki): Fix the following commented out lines
# as `gradient_accumulation_fusion` only takes 3D tensors.
# Values. # Values.
dLdY = loss_weight # dLdY = loss_weight # (7, 9, 17)
X = identity_layer.weight # X = identity_layer.weight # (7, 9, 13)
A = linear_layer.master_weight.cuda() # A = linear_layer.master_weight.cuda() # (17, 13)
dLdA = torch.matmul(dLdY.t(), X) # print(f"dLdY.shape, X.shape, A.shape = {dLdY.shape, X.shape, A.shape}")
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) # dLdA = torch.matmul(dLdY.view(-1, 17).t(), X.view(-1, 13))
dLdX = torch.matmul(dLdY, A) # print(f"dLdA.shape = {dLdA.shape}")
# ones = torch.ones(batch_size, hidden_size, 1).cuda()
rank = parallel_state.get_tensor_model_parallel_rank() # print(f"dLdY.shape, ones.shape = {dLdY.shape, ones.shape}")
my_dLdA = torch.split(dLdA, output_size_coeff, # dLdb = torch.matmul(ones, dLdY).view(-1)
dim=0)[rank].contiguous().clone() # dLdX = torch.matmul(dLdY, A)
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
torch.distributed.barrier() # rank = parallel_state.get_tensor_model_parallel_rank()
print(' error in dLdA on global rank {}: {}'.format( # my_dLdA = torch.split(dLdA, output_size_coeff,
torch.distributed.get_rank(), error)) # dim=0)[rank].contiguous().clone()
assert error < 1.0e-6 # error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
# torch.distributed.barrier()
my_dLdb = torch.split(dLdb, output_size_coeff, # print(' error in dLdA on global rank {}: {}'.format(
dim=0)[rank].contiguous().clone() # torch.distributed.get_rank(), error))
error = my_dLdb.sub(linear_layer.bias.grad).abs().max() # assert error < 1.0e-6
torch.distributed.barrier()
print(' error in dLdb on global rank {}: {}'.format( # my_dLdb = torch.split(dLdb, output_size_coeff,
torch.distributed.get_rank(), error)) # dim=0)[rank].contiguous().clone()
assert error < 1.0e-6 # error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
# torch.distributed.barrier()
error = dLdX.sub(identity_layer.weight.grad).abs().max() # print(' error in dLdb on global rank {}: {}'.format(
torch.distributed.barrier() # torch.distributed.get_rank(), error))
print(' error in dLdX on global rank {}: {}'.format( # assert error < 1.0e-6
torch.distributed.get_rank(), error))
assert error < 1.0e-6 # error = dLdX.sub(identity_layer.weight.grad).abs().max()
# torch.distributed.barrier()
# print(' error in dLdX on global rank {}: {}'.format(
# torch.distributed.get_rank(), error))
# assert error < 1.0e-6
# Reset groups # Reset groups
parallel_state.destroy_model_parallel() parallel_state.destroy_model_parallel()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment