Unverified commit ddc08039 authored by Masaki Kozuki, committed by GitHub

[transformer] Fuse grad accumulation with wgrad (#1297)



* fuse grad accumulation w/ weight grad
Co-authored-by: Sangkug Lym <slym@nvidia.com>

* fp32 training path

* not using *args, **kwargs

* backward: moved the tensor dimension conversion
Co-authored-by: Sangkug Lym <slym@nvidia.com>

* move files to csrc/megatron

* fix fp32 path

* fix typo

* add  to  in order to select the correct custom extension

* fix typo

* comment on import guard

* update test: enable gradient_accumulation_fusion

* 86

* remove redundant call of `test_column_parallel_linear`
Co-authored-by: Sangkug Lym <slym@nvidia.com>
parent e95c3b9c
......@@ -32,6 +32,26 @@ from apex.transformer.tensor_parallel.mappings import reduce_from_tensor_model_p
from apex.transformer.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker
from apex.transformer.tensor_parallel.utils import VocabUtility
from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)
_grad_accum_fusion_available = False
try:
import fused_weight_gradient_mlp_cuda
except ImportError:
# Basically, apex.transformer module users are expected to install APEX with
# the `--cpp_ext` and `--cuda_ext` extensions. An example installation command,
# run at the root of the APEX repository, is:
# `pip install --global-option="--cpp_ext" --global-option="--cuda_ext" .`
_logger.warning(
"`fused_weight_gradient_mlp_cuda` module not found. "
"gradient accumulation fusion with weight gradient computation disabled."
)
else:
_grad_accum_fusion_available = True
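For illustration, a minimal sketch (hypothetical helper name, not part of this commit) of how a caller can gate a user request on this availability flag, mirroring what `ColumnParallelLinear` does further below:

```python
# Hypothetical helper: gate the fused wgrad path on extension availability.
def _resolve_gradient_accumulation_fusion(requested: bool) -> bool:
    if requested and not _grad_accum_fusion_available:
        _logger.warning(
            "gradient_accumulation_fusion requested but the "
            "`fused_weight_gradient_mlp_cuda` extension is unavailable; disabling it."
        )
    return requested and _grad_accum_fusion_available
```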
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
......@@ -203,15 +223,66 @@ class VocabParallelEmbedding(torch.nn.Module):
return output
class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
"""
Column-parallel linear layer execution with asynchronous all-reduce
execution in backprop.
"""
class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
"""Linear layer execution with asynchronous all-reduce and gradient accumulation fusion in backprop."""
@staticmethod
def forward(ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight)
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
input = input.view(input.shape[0] * input.shape[1], input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None
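For reference, when fusion is enabled the call above accumulates the weight gradient into the preallocated `weight.main_grad` buffer instead of returning it through autograd. A sketch of the equivalent eager-mode computation (not the kernel itself):

```python
def wgrad_accum_fp32_reference(input_2d, grad_output_2d, main_grad):
    """Eager-mode equivalent of wgrad_gemm_accum_fp32 (a sketch, not the kernel).

    input_2d:       [tokens, in_features]
    grad_output_2d: [tokens, out_features]
    main_grad:      [out_features, in_features], torch.float32, accumulated in place
    """
    main_grad.add_(grad_output_2d.float().t() @ input_2d.float())
```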
def linear_with_grad_accumulation_and_async_allreduce(
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
):
args = _cast_if_autocast_enabled(input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce)
with torch.cuda.amp.autocast(enabled=False):
return LinearWithGradAccumulationAndAsyncAllreduce.apply(*args)
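A hedged usage sketch of this wrapper, assuming the fused extension is installed. With `async_grad_allreduce=False` no process group is required; the input must be 3D and the weight must carry a preallocated fp32 `main_grad` buffer:

```python
import torch

seq_len, batch, in_features, out_features = 4, 2, 8, 16
x = torch.randn(seq_len, batch, in_features, device="cuda", requires_grad=True)
weight = torch.nn.Parameter(torch.randn(out_features, in_features, device="cuda"))
weight.main_grad = torch.zeros(out_features, in_features, device="cuda")  # fp32 buffer

out = linear_with_grad_accumulation_and_async_allreduce(
    x, weight, None,  # input, weight, bias
    gradient_accumulation_fusion=True,
    async_grad_allreduce=False,
)
out.sum().backward()  # wgrad lands in weight.main_grad; weight.grad stays None
```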
class LinearWithGradAccumulationAndAsyncAllreduceIn16Bit(torch.autograd.Function):
"""Linear layer execution with asynchronous all-reduce and gradient accumulation fusion in backprop."""
@staticmethod
def forward(ctx, input, weight, bias):
def forward(ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
......@@ -222,22 +293,38 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight)
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(input)
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
input = input.view(input.shape[0] * input.shape[1], input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(input, grad_output, weight.main_grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
handle.wait()
return grad_input, grad_weight, grad_bias
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None
def column_parallel_linear(input, weight, bias):
args = _cast_if_autocast_enabled(input, weight, bias)
def linear_with_grad_accumulation_and_async_allreduce_in16bit(
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
):
args = _cast_if_autocast_enabled(input, weight, bias, gradient_accumulation_fusion, async_grad_allreduce)
with torch.cuda.amp.autocast(enabled=False):
return ColumnParallelLinearWithAsyncAllreduce.apply(*args)
return LinearWithGradAccumulationAndAsyncAllreduceIn16Bit.apply(*args)
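One practical difference between the two autograd paths, based on the CUDA stubs further below: `wgrad_gemm_accum_fp32` accumulates into an fp32 `main_grad`, whereas `wgrad_gemm_accum_fp16` accumulates directly into a half/bfloat16 buffer. A minimal allocation sketch (hypothetical helper):

```python
import torch

def allocate_main_grad(weight: torch.nn.Parameter, accumulation_in_fp16: bool) -> None:
    """Preallocate the buffer expected by the fused wgrad kernels (sketch)."""
    if accumulation_in_fp16:
        # 16-bit path: assumes the weight itself is already fp16/bf16.
        weight.main_grad = torch.zeros_like(weight)
    else:
        # fp32 path: accumulate in full precision regardless of the weight dtype.
        weight.main_grad = torch.zeros(
            weight.shape, dtype=torch.float32, device=weight.device
        )
```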
class ColumnParallelLinear(torch.nn.Module):
......@@ -262,6 +349,13 @@ class ColumnParallelLinear(torch.nn.Module):
skip_bias_add: This was added to enable performance optimizations where the bias
can be fused with other elementwise operations. We skip
adding the bias here and instead return it.
Keyword Arguments:
no_async_tensor_model_parallel_allreduce: If True, do not overlap the backprop
input-gradient all-reduce with the weight-gradient computation.
params_dtype: dtype used to initialize the weight and bias parameters.
use_cpu_initialization: If True, initialize the parameters on CPU.
gradient_accumulation_fusion: If True, fuse the weight-gradient GEMM with gradient
accumulation into `weight.main_grad` (requires the `fused_weight_gradient_mlp_cuda` extension).
accumulation_in_fp16: If True, accumulate the fused weight gradient in 16-bit
precision instead of fp32.
"""
def __init__(
......@@ -278,6 +372,8 @@ class ColumnParallelLinear(torch.nn.Module):
no_async_tensor_model_parallel_allreduce=False,
params_dtype=torch.float32,
use_cpu_initialization=False,
gradient_accumulation_fusion=False,
accumulation_in_fp16: bool = False,
):
super(ColumnParallelLinear, self).__init__()
......@@ -335,24 +431,22 @@ class ColumnParallelLinear(torch.nn.Module):
self.async_tensor_model_parallel_allreduce = (
not no_async_tensor_model_parallel_allreduce and
world_size > 1)
self.gradient_accumulation_fusion = gradient_accumulation_fusion and _grad_accum_fusion_available
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce_in16bit if accumulation_in_fp16 else linear_with_grad_accumulation_and_async_allreduce
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
if self.async_tensor_model_parallel_allreduce:
input_shape = input_.shape
input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
# Matrix multiply with asynchronous all-reduce execution
output_parallel = column_parallel_linear(input_, self.weight, bias)
output_parallel = output_parallel.view(
input_shape[0], input_shape[1], output_parallel.shape[1])
else:
if not self.async_tensor_model_parallel_allreduce:
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, bias)
else:
input_parallel = input_
# Matrix multiply.
output_parallel = self._forward_impl(
input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
self.async_tensor_model_parallel_allreduce)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
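Putting the Python side together, a hedged end-to-end sketch patterned on the updated unit test further below. It assumes apex was built with `--cpp_ext`/`--cuda_ext`, that `torch.distributed` and `parallel_state.initialize_model_parallel(1)` have been set up, and that the constructor defaults (bias, init_method, gather_output) apply:

```python
import torch
from apex.transformer import parallel_state, tensor_parallel

layer = tensor_parallel.ColumnParallelLinear(
    13, 17, gradient_accumulation_fusion=True
).cuda()
layer.weight.main_grad = torch.zeros_like(layer.weight, dtype=torch.float32)

x = torch.randn(7, 9, 13, device="cuda", requires_grad=True)  # [batch, seq, hidden]
out, _ = layer(x)
out.sum().backward()  # weight gradient accumulates into layer.weight.main_grad
```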
......
#include <torch/extension.h>
#include <cstdio>
#include <vector>
void wgrad_gemm_accum_fp32_cuda_stub(
at::Tensor &input_2d,
at::Tensor &d_output_2d,
at::Tensor &d_weight
);
void wgrad_gemm_accum_fp16_cuda_stub(
at::Tensor &input_2d,
at::Tensor &d_output_2d,
at::Tensor &d_weight
);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_cuda_stub, "wgrad gemm accum in fp32");
m.def("wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_cuda_stub, "wgrad gemm accum in fp16");
}
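When the extension has been built, the binding can also be exercised directly from Python; a hedged sanity check of its accumulate-into-`d_weight` semantics (beta = 1.0 in the GEMM below), using the argument order from the stub signatures above and a loose tolerance in case the GEMM runs in TF32:

```python
import torch
import fused_weight_gradient_mlp_cuda  # only available when built with --cuda_ext

tokens, in_dim, out_dim = 64, 32, 48
inp = torch.randn(tokens, in_dim, device="cuda")
d_out = torch.randn(tokens, out_dim, device="cuda")
d_weight = torch.randn(out_dim, in_dim, device="cuda", dtype=torch.float32)

expected = d_weight + d_out.t() @ inp  # accumulates into d_weight, does not overwrite
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(inp, d_out, d_weight)
assert torch.allclose(d_weight, expected, rtol=1e-3, atol=1e-3)
```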
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "type_shim.h"
// BF16 inputs and BF16 accumulation
void gemmex_wrapper_fp16(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::BFloat16* A,
int lda,
at::BFloat16* B,
int ldb,
const float* beta,
at::BFloat16* C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16BF,
lda,
B,
CUDA_R_16BF,
ldb,
beta,
C,
CUDA_R_16BF,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP16 inputs and FP16 accumulation
void gemmex_wrapper_fp16(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float* beta,
at::Half* C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_16F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
template <typename T>
void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *d_weight, int in_dim, int hidden_dim, int out_dim) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta = 1.0;
gemmex_wrapper_fp16(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_dim,
out_dim,
hidden_dim,
&alpha,
input,
in_dim,
d_output,
out_dim,
&beta,
d_weight,
in_dim);
}
template void wgrad_gemm_accum_fp16_cuda<at::Half>(at::Half *input, at::Half *d_output, at::Half *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp16_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, at::BFloat16 *d_weight, int in_dim, int hidden_dim, int out_dim);
void wgrad_gemm_accum_fp16_cuda_stub(
at::Tensor &input,
at::Tensor &d_output,
at::Tensor &d_weight
) {
at::Tensor input_2d, d_output_2d;
// input tensor: collapse to the first dim
auto in_sizes = input.sizes();
if (input.dim() > 2) {
input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
} else {
input_2d = input;
}
// d_output tensor: collapse to the first dim
auto d_out_sizes = d_output.sizes();
if (d_output.dim() > 2) {
d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
} else {
d_output_2d = d_output;
}
const int hidden_dim = input_2d.size(0);
const int in_dim = input_2d.size(1);
const int out_dim = d_weight.size(0);
DISPATCH_HALF_AND_BFLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp16",
wgrad_gemm_accum_fp16_cuda<scalar_t>(
input_2d.data_ptr<scalar_t>(),
d_output_2d.data_ptr<scalar_t>(),
d_weight.data_ptr<scalar_t>(),
in_dim,
hidden_dim,
out_dim);
);
}
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "type_shim.h"
// BF16 Tensor core wrapper around cublas GEMMEx
void gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::BFloat16* A,
int lda,
at::BFloat16* B,
int ldb,
const float* beta,
float* C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16BF,
lda,
B,
CUDA_R_16BF,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP16 Tensor core wrapper around cublas GEMMEx
void gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float* beta,
float* C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// FP32 wrapper around cublas GEMMEx
void gemmex_wrapper(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha,
float *A,
int lda,
float *B,
int ldb,
const float *beta,
float *C,
int ldc) {
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_32F,
lda,
B,
CUDA_R_32F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
template <typename T>
void wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta = 1.0;
gemmex_wrapper(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_dim,
out_dim,
hidden_dim,
&alpha,
input,
in_dim,
d_output,
out_dim,
&beta,
d_weight,
in_dim);
}
template void wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template void wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
void wgrad_gemm_accum_fp32_cuda_stub(
at::Tensor &input,
at::Tensor &d_output,
at::Tensor &d_weight
) {
at::Tensor input_2d, d_output_2d;
// input tensor: collapse to the first dim
auto in_sizes = input.sizes();
if (input.dim() > 2) {
input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
} else {
input_2d = input;
}
// d_output tensor: collapse to the first dim
auto d_out_sizes = d_output.sizes();
if (d_output.dim() > 2) {
d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
} else {
d_output_2d = d_output;
}
const int hidden_dim = input_2d.size(0);
const int in_dim = input_2d.size(1);
const int out_dim = d_weight.size(0);
DISPATCH_FLOAT_HALF_AND_BFLOAT(input_2d.scalar_type(), 0, "wgrad_gemm_accum_fp32",
wgrad_gemm_accum_fp32_cuda<scalar_t_0>(
input_2d.data_ptr<scalar_t_0>(),
d_output_2d.data_ptr<scalar_t_0>(),
d_weight.data_ptr<float>(),
in_dim,
hidden_dim,
out_dim);
);
}
......@@ -297,6 +297,44 @@ if "--cuda_ext" in sys.argv:
)
)
# Check if CUDA 11 is installed for compute capability 8.0 (and 8.6 for CUDA >= 11.1)
cc_flag = []
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if int(bare_metal_minor) > 0:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_86,code=sm_86")
ext_modules.append(
CUDAExtension(
name="fused_weight_gradient_mlp_cuda",
include_dirs=[os.path.join(this_dir, "csrc")],
sources=[
"csrc/megatron/fused_weight_gradient_dense.cpp",
"csrc/megatron/fused_weight_gradient_dense_cuda.cu",
"csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu",
],
extra_compile_args={
"cxx": ["-O3"] + version_dependent_macros,
"nvcc": append_nvcc_threads(
[
"-O3",
"-gencode",
"arch=compute_70,code=sm_70",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
]
+ version_dependent_macros
+ cc_flag
),
},
)
)
if "--permutation_search" in sys.argv:
sys.argv.remove("--permutation_search")
......@@ -499,12 +537,13 @@ if "--fast_multihead_attn" in sys.argv:
# Check if CUDA 11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_86,code=sm_86")
if int(bare_metal_minor) > 0:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_86,code=sm_86")
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append(
......
......@@ -220,52 +220,65 @@ def test_column_parallel_linear(tensor_model_parallel_size):
output_size_coeff = 17
output_size = output_size_coeff * tensor_model_parallel_size
batch_size = 7
hidden_size = 9
# Network
identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
gradient_accumulation_fusion = True
identity_layer = IdentityLayer3D(batch_size, hidden_size, input_size).cuda()
linear_layer = layers.ColumnParallelLinear(
input_size, output_size, keep_master_weight_for_test=True,
params_dtype=global_vars.get_args().params_dtype,
use_cpu_initialization=global_vars.get_args().use_cpu_initialization,
gradient_accumulation_fusion=gradient_accumulation_fusion,
).cuda()
loss_weight = torch.randn([batch_size, output_size]).cuda()
with torch.no_grad():
linear_layer.weight.main_grad = torch.randn_like(linear_layer.weight)
loss_weight = torch.randn([batch_size, hidden_size, output_size]).cuda()
# Forward
input_ = identity_layer()
output, _ = linear_layer(input_)
assert list(output.shape) == [batch_size, hidden_size, output_size]
loss = torch.mul(output, loss_weight).sum()
# Backward
loss.backward()
# TODO (mkozuki): Fix the commented-out checks below,
# as `gradient_accumulation_fusion` only takes 3D tensors.
# Values.
dLdY = loss_weight
X = identity_layer.weight
A = linear_layer.master_weight.cuda()
dLdA = torch.matmul(dLdY.t(), X)
dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
dLdX = torch.matmul(dLdY, A)
rank = parallel_state.get_tensor_model_parallel_rank()
my_dLdA = torch.split(dLdA, output_size_coeff,
dim=0)[rank].contiguous().clone()
error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdA on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
my_dLdb = torch.split(dLdb, output_size_coeff,
dim=0)[rank].contiguous().clone()
error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdb on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
error = dLdX.sub(identity_layer.weight.grad).abs().max()
torch.distributed.barrier()
print(' error in dLdX on global rank {}: {}'.format(
torch.distributed.get_rank(), error))
assert error < 1.0e-6
# dLdY = loss_weight # (7, 9, 17)
# X = identity_layer.weight # (7, 9, 13)
# A = linear_layer.master_weight.cuda() # (17, 13)
# print(f"dLdY.shape, X.shape, A.shape = {dLdY.shape, X.shape, A.shape}")
# dLdA = torch.matmul(dLdY.view(-1, 17).t(), X.view(-1, 13))
# print(f"dLdA.shape = {dLdA.shape}")
# ones = torch.ones(batch_size, hidden_size, 1).cuda()
# print(f"dLdY.shape, ones.shape = {dLdY.shape, ones.shape}")
# dLdb = torch.matmul(ones, dLdY).view(-1)
# dLdX = torch.matmul(dLdY, A)
# rank = parallel_state.get_tensor_model_parallel_rank()
# my_dLdA = torch.split(dLdA, output_size_coeff,
# dim=0)[rank].contiguous().clone()
# error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
# torch.distributed.barrier()
# print(' error in dLdA on global rank {}: {}'.format(
# torch.distributed.get_rank(), error))
# assert error < 1.0e-6
# my_dLdb = torch.split(dLdb, output_size_coeff,
# dim=0)[rank].contiguous().clone()
# error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
# torch.distributed.barrier()
# print(' error in dLdb on global rank {}: {}'.format(
# torch.distributed.get_rank(), error))
# assert error < 1.0e-6
# error = dLdX.sub(identity_layer.weight.grad).abs().max()
# torch.distributed.barrier()
# print(' error in dLdX on global rank {}: {}'.format(
# torch.distributed.get_rank(), error))
# assert error < 1.0e-6
# Reset groups
parallel_state.destroy_model_parallel()
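A hedged sketch (hypothetical helper, not part of this commit) of how the commented-out checks above could be adapted to the 3D input and the fused path: flatten the leading dims and compare against the delta of `weight.main_grad`, since the fused kernel accumulates into the preallocated buffer and `weight.grad` stays None.

```python
def check_fused_wgrad(identity_layer, linear_layer, loss_weight,
                      main_grad_before, output_size_coeff, rank):
    # Flatten (batch, hidden) so shapes match the 2D wgrad GEMM.
    dLdY = loss_weight.reshape(-1, loss_weight.shape[-1])                      # [(b*h), out]
    X = identity_layer.weight.detach().reshape(-1, identity_layer.weight.shape[-1])
    dLdA = dLdY.t() @ X                                                        # full weight grad
    my_dLdA = torch.split(dLdA, output_size_coeff, dim=0)[rank].contiguous()
    # With fusion, the wgrad is accumulated into main_grad, so compare the delta.
    accumulated = linear_layer.weight.main_grad - main_grad_before
    # Tolerance may need loosening if the fused GEMM runs in TF32.
    assert my_dLdA.sub(accumulated).abs().max() < 1e-5
```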
......