"sgl-kernel/python/vscode:/vscode.git/clone" did not exist on "8aa68ed5c444dcce7ff9cb2875943bb7f0257e6b"
Commit c2b62b7f authored by JR_ZZU

delete origin files

parent 2a4864d5
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import fast_multihead_attn
class FastEncdecAttnNormAddFunc(torch.autograd.Function):
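# Autograd wrapper around the fused encoder-decoder attention kernel that also
# fuses layer norm and the residual dropout-add path; forward/backward dispatch
# to the fast_multihead_attn CUDA extension.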
@staticmethod
def forward(
ctx,
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
pad_mask,
dropout_prob,
):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = pad_mask is not None
(
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
input_lin_q_results,
input_lin_kv_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
dropout_add_mask,
outputs,
) = fast_multihead_attn.encdec_multihead_attn_norm_add_forward(
use_mask,
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
pad_mask if use_mask else null_tensor,
dropout_prob,
)
ctx.save_for_backward(
heads_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob_t,
)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
(
heads_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob_t,
) = ctx.saved_tensors
(
input_q_grads,
input_kv_grads,
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads,
) = fast_multihead_attn.encdec_multihead_attn_norm_add_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob_t[0],
)
# import pdb; pdb.set_trace()
return (
None,
None,
None,
input_q_grads,
input_kv_grads,
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads,
None,
None,
)
fast_encdec_attn_norm_add_func = FastEncdecAttnNormAddFunc.apply
import torch
import fast_multihead_attn
class FastSelfAttnFunc(torch.autograd.Function):
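# Autograd wrapper around the fused self-attention kernels. Depending on the
# arguments, forward() dispatches to the biased, biased-additive-mask, or
# unbiased variant of the fast_multihead_attn extension.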
@staticmethod
def forward(
ctx,
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
pad_mask,
mask_additive,
dropout_prob,
):
use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = pad_mask is not None
mask_additive_t = torch.tensor([mask_additive])
if use_biases_t[0]:
if not mask_additive:
(
input_lin_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs,
) = fast_multihead_attn.self_attn_bias_forward(
use_mask,
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
pad_mask if use_mask else null_tensor,
dropout_prob,
)
# fast_self_multihead_attn_bias.forward() \
ctx.save_for_backward(
use_biases_t,
heads_t,
matmul2_results,
dropout_results,
softmax_results,
null_tensor,
null_tensor,
mask_additive_t,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t,
)
else:
(
input_lin_results,
bmm1_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs,
) = fast_multihead_attn.self_attn_bias_additive_mask_forward(
use_mask,
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
pad_mask if use_mask else null_tensor,
dropout_prob,
)
# fast_self_multihead_attn_bias_additive_mask.forward( \
ctx.save_for_backward(
use_biases_t,
heads_t,
matmul2_results,
dropout_results,
null_tensor,
bmm1_results,
pad_mask,
mask_additive_t,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t,
)
else:
(
input_lin_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs,
) = fast_multihead_attn.self_attn_forward(
use_mask,
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
pad_mask if use_mask else null_tensor,
dropout_prob,
)
# fast_self_multihead_attn.forward( \
ctx.save_for_backward(
use_biases_t,
heads_t,
matmul2_results,
dropout_results,
softmax_results,
null_tensor,
null_tensor,
mask_additive_t,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t,
)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
(
use_biases_t,
heads_t,
matmul2_results,
dropout_results,
softmax_results,
bmm1_results,
pad_mask,
mask_additive_t,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t,
) = ctx.saved_tensors
if use_biases_t[0]:
if not mask_additive_t[0]:
(
input_grads,
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads,
) = fast_multihead_attn.self_attn_bias_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t[0],
)
# fast_self_multihead_attn_bias.backward( \
else:
(
input_grads,
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads,
) = fast_multihead_attn.self_attn_bias_additive_mask_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
bmm1_results,
pad_mask,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t[0],
)
# fast_self_multihead_attn_bias_additive_mask.backward( \
else:
input_bias_grads = None
output_bias_grads = None
input_grads, input_weight_grads, output_weight_grads = fast_multihead_attn.self_attn_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t[0],
)
# fast_self_multihead_attn.backward( \
return (
None,
None,
None,
input_grads,
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads,
None,
None,
None,
)
fast_self_attn_func = FastSelfAttnFunc.apply
import torch
import fast_multihead_attn
class FastSelfAttnNormAddFunc(torch.autograd.Function):
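# Autograd wrapper around the fused self-attention kernel that also fuses
# layer norm and the residual dropout-add path.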
@staticmethod
def forward(
ctx,
use_time_mask,
is_training,
heads,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
pad_mask,
dropout_prob,
):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = pad_mask is not None
(
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
input_lin_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
dropout_add_mask,
outputs,
) = fast_multihead_attn.self_attn_norm_add_forward(
use_mask,
use_time_mask,
is_training,
heads,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
pad_mask if use_mask else null_tensor,
dropout_prob,
)
# fast_self_multihead_attn_norm_add.forward( \
ctx.save_for_backward(
heads_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob_t,
)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
(
heads_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob_t,
) = ctx.saved_tensors
(
input_grads,
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_grads,
output_weight_grads,
) = fast_multihead_attn.self_attn_norm_add_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob_t[0],
)
# fast_self_multihead_attn_norm_add.backward( \
return (
None,
None,
None,
input_grads,
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_grads,
output_weight_grads,
None,
None,
)
fast_self_attn_norm_add_func = FastSelfAttnNormAddFunc.apply
import torch
import fast_multihead_attn
class MaskSoftmaxDropout(torch.autograd.Function):
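# Autograd wrapper around the fused (masked) softmax + dropout kernels; the
# additive-mask and binary-mask variants use different extension entry points.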
@staticmethod
def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = pad_mask is not None
use_mask_t = torch.tensor([use_mask])
mask_additive_t = torch.tensor([mask_additive])
if mask_additive:
dropout_results, dropout_mask, softmax_results = fast_multihead_attn.additive_mask_softmax_dropout_forward(
use_mask, is_training, heads, inputs, pad_mask if use_mask else null_tensor, dropout_prob
)
# fast_additive_mask_softmax_dropout.forward( \
else:
dropout_results, dropout_mask, softmax_results = fast_multihead_attn.mask_softmax_dropout_forward(
use_mask, is_training, heads, inputs, pad_mask if use_mask else null_tensor, dropout_prob
)
# fast_mask_softmax_dropout.forward( \
ctx.save_for_backward(
use_mask_t,
heads_t,
softmax_results,
dropout_mask,
pad_mask if use_mask else null_tensor,
mask_additive_t,
dropout_prob_t,
)
return dropout_results.detach()
@staticmethod
def backward(ctx, output_grads):
(
use_mask_t,
heads_t,
softmax_results,
dropout_mask,
pad_mask,
mask_additive_t,
dropout_prob_t,
) = ctx.saved_tensors
if mask_additive_t[0]:
input_grads = fast_multihead_attn.additive_mask_softmax_dropout_backward(
use_mask_t[0], heads_t[0], output_grads, softmax_results, dropout_mask, dropout_prob_t[0]
)
# fast_additive_mask_softmax_dropout.backward( \
else:
input_grads = fast_multihead_attn.mask_softmax_dropout_backward(
use_mask_t[0], heads_t[0], output_grads, softmax_results, dropout_mask, pad_mask, dropout_prob_t[0]
)
# fast_mask_softmax_dropout.backward( \
return None, None, input_grads, None, None, None
fast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply
import math
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from .self_multihead_attn_func import self_attn_func
from .fast_self_multihead_attn_func import fast_self_attn_func
from .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func
from apex.normalization.fused_layer_norm import FusedLayerNorm
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
# type: (Tensor, Tensor, float, bool) -> Tensor
out = F.dropout(x, p=prob, training=True)
out = residual + out
return out
class SelfMultiheadAttn(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=False,
include_norm_add=False,
impl="fast",
separate_qkv_params=False,
mask_additive=False,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.bias = bias
self.include_norm_add = include_norm_add
self.impl = impl
self.scaling = self.head_dim ** -0.5
self.separate_qkv_params = separate_qkv_params
self.mask_additive = mask_additive
if mask_additive:
assert self.include_norm_add == False, "additive mask not supported with layer norm"
assert impl == "default" or (
impl == "fast" and bias
), "additive mask not supported for fast mode without bias"
if separate_qkv_params:
self.q_weight = Parameter(torch.empty(embed_dim, embed_dim))
self.k_weight = Parameter(torch.empty(embed_dim, embed_dim))
self.v_weight = Parameter(torch.empty(embed_dim, embed_dim))
else:
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
self.out_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
if self.bias:
if separate_qkv_params:
self.q_bias = Parameter(torch.empty(embed_dim))
self.k_bias = Parameter(torch.empty(embed_dim))
self.v_bias = Parameter(torch.empty(embed_dim))
else:
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
self.out_proj_bias = Parameter(torch.empty(embed_dim))
else:
if separate_qkv_params:
self.register_parameter("q_bias", None)
self.register_parameter("k_bias", None)
self.register_parameter("v_bias", None)
self.q_bias = None
self.k_bias = None
self.v_bias = None
else:
self.register_parameter("in_proj_bias", None)
self.in_proj_bias = None
self.register_parameter("out_proj_bias", None)
self.out_proj_bias = None
if self.include_norm_add:
if impl == "fast":
self.lyr_nrm_gamma_weights = Parameter(torch.empty(embed_dim))
self.lyr_nrm_beta_weights = Parameter(torch.empty(embed_dim))
self.lyr_nrm = None
else:
self.register_parameter("lyr_norm_gamma_weights", None)
self.register_parameter("lyr_norm_beta_weights", None)
self.lyr_nrm_gamma_weights = None
self.lyr_nrm_beta_weights = None
self.lyr_nrm = FusedLayerNorm(embed_dim)
self.reset_parameters()
if self.include_norm_add:
if impl == "fast":
self.attn_func = fast_self_attn_norm_add_func
elif impl == "default":
self.attn_func = self_attn_func
else:
assert False, "Unsupported impl: {} !".format(impl)
else:
if impl == "fast":
self.attn_func = fast_self_attn_func
elif impl == "default":
self.attn_func = self_attn_func
else:
assert False, "Unsupported impl: {} !".format(impl)
def reset_parameters(self):
if self.separate_qkv_params:
nn.init.xavier_uniform_(self.q_weight)
nn.init.xavier_uniform_(self.k_weight)
nn.init.xavier_uniform_(self.v_weight)
else:
# in_proj_weight has shape [3 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)
# therefore xavier_uniform gain should be set to sqrt(2).
nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
if self.separate_qkv_params:
nn.init.constant_(self.q_bias, 0.0)
nn.init.constant_(self.k_bias, 0.0)
nn.init.constant_(self.v_bias, 0.0)
else:
nn.init.constant_(self.in_proj_bias, 0.0)
nn.init.constant_(self.out_proj_bias, 0.0)
if self.include_norm_add:
if self.impl == "fast":
nn.init.ones_(self.lyr_nrm_gamma_weights)
nn.init.zeros_(self.lyr_nrm_beta_weights)
else:
self.lyr_nrm.reset_parameters()
def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
query, key and value. Future timesteps can be masked with the
`mask_future_timesteps` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
if self.separate_qkv_params:
input_weights = (
torch.cat(
[
self.q_weight.view(self.num_heads, 1, self.head_dim, self.embed_dim),
self.k_weight.view(self.num_heads, 1, self.head_dim, self.embed_dim),
self.v_weight.view(self.num_heads, 1, self.head_dim, self.embed_dim),
],
dim=1,
)
.reshape(3 * self.embed_dim, self.embed_dim)
.contiguous()
)
else:
input_weights = self.in_proj_weight
if self.bias:
if self.separate_qkv_params:
input_bias = (
torch.cat(
[
self.q_bias.view(self.num_heads, 1, self.head_dim),
self.k_bias.view(self.num_heads, 1, self.head_dim),
self.v_bias.view(self.num_heads, 1, self.head_dim),
],
dim=1,
)
.reshape(3 * self.embed_dim)
.contiguous()
)
else:
input_bias = self.in_proj_bias
else:
input_bias = None
if key_padding_mask is not None:
assert attn_mask is None, "ERROR attn_mask and key_padding_mask should not both be defined!"
mask = key_padding_mask
elif attn_mask is not None:
assert self.mask_additive == False, "additive mask not supported for time mask"
mask = attn_mask
else:
mask = None
if self.include_norm_add:
if self.impl == "fast":
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
query,
self.lyr_nrm_gamma_weights,
self.lyr_nrm_beta_weights,
input_weights,
self.out_proj_weight,
mask,
self.dropout,
)
else:
lyr_nrm_results = self.lyr_nrm(query)
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
self.scaling,
lyr_nrm_results,
input_weights,
self.out_proj_weight,
input_bias,
self.out_proj_bias,
mask,
self.mask_additive,
self.dropout,
)
if is_training:
outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else:
outputs = outputs + query
else:
if self.impl == "fast":
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
query,
input_weights,
self.out_proj_weight,
input_bias,
self.out_proj_bias,
mask,
self.mask_additive,
self.dropout,
)
else:
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
self.scaling,
query,
input_weights,
self.out_proj_weight,
input_bias,
self.out_proj_bias,
mask,
self.mask_additive,
self.dropout,
)
return outputs, None
import torch
import torch.nn.functional as F
class SelfAttnFunc(torch.autograd.Function):
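# Pure-PyTorch reference implementation of the self-attention autograd
# function, used when the "default" (non-fused) impl is selected.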
@staticmethod
def forward(
ctx,
use_time_mask,
is_training,
heads,
scale,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
mask,
is_additive_mask,
dropout_prob,
):
use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
head_dim = inputs.size(2) // heads
# Input Linear GEMM
# input1: (activations) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim*3 (3072), embed_dim (1024)] (transpose [0,1])
# output: [seql_q, seqs, embed_dim*3]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim*3 ) = (seql_q*seqs x embed_dim*3)
if use_biases_t[0]:
input_lin_results = torch.addmm(
input_biases,
inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
input_weights.transpose(0, 1),
beta=1.0,
alpha=1.0,
)
else:
input_lin_results = torch.mm(
inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)), input_weights.transpose(0, 1)
)
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1), input_weights.size(0))
# Slice out q,k,v from one big Input Linear output (should only impact metadata, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
# input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1) * heads, 3, head_dim)
queries = input_lin_results[:, :, 0, :]
keys = input_lin_results[:, :, 1, :]
values = input_lin_results[:, :, 2, :]
# Matmul1 Batched GEMMs
# The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
# baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
# a separate elementwise operation.
# Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1)
# output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul1_results = torch.empty(
(queries.size(1), queries.size(0), keys.size(0)), dtype=queries.dtype, device=torch.device("cuda")
)
matmul1_results = torch.baddbmm(
matmul1_results,
queries.transpose(0, 1),
keys.transpose(0, 1).transpose(1, 2),
out=matmul1_results,
beta=0.0,
alpha=scale_t[0],
)
if mask is not None:
# Self Attention Time Mask
if use_time_mask:
assert len(mask.size()) == 2, "Timing mask is not 2D!"
assert mask.size(0) == mask.size(1), "Sequence length should match!"
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask, float("-inf"))
# Key Padding Mask
else:
batches, seql_q, seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
if is_additive_mask:
matmul1_results = matmul1_results + mask.unsqueeze(1).unsqueeze(2)
else:
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float("-inf"))
matmul1_results = matmul1_results.view(seqs * heads, seql_q, seql_k)
softmax_results = F.softmax(matmul1_results, dim=-1)
# Dropout - is not executed for inference
if is_training:
dropout_results, dropout_mask = torch._fused_dropout(softmax_results, p=(1.0 - dropout_prob_t[0]))
else:
dropout_results = softmax_results
dropout_mask = null_tensor
# Matmul2 Batched GEMMs
# The output tensor specification is needed here to specify the non-standard output.
# Given that pytorch cannot currently perform autograd with an output tensor specified,
# this requires a backward pass specified.
# Input1: from_softmax [seqs*heads, seql_q, seql_k]
# Input2: (values) [seql_v, seqs*heads, head_dim] transpose(0,1)
# Output: [seql_q, seqs*heads, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)
matmul2_results = torch.empty(
(dropout_results.size(1), dropout_results.size(0), values.size(2)),
dtype=dropout_results.dtype,
device=torch.device("cuda"),
).transpose(1, 0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0, 1), out=matmul2_results)
matmul2_results = (
matmul2_results.transpose(0, 1).contiguous().view(inputs.size(0), inputs.size(1), inputs.size(2))
)
# Output Linear GEMM
# Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
if use_biases_t[0]:
outputs = torch.addmm(
output_biases,
matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
output_weights.transpose(0, 1),
beta=1.0,
alpha=1.0,
)
else:
outputs = torch.mm(
matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)), output_weights.transpose(0, 1)
)
outputs = outputs.view(inputs.size(0), inputs.size(1), output_weights.size(0))
ctx.save_for_backward(
use_biases_t,
heads_t,
scale_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t,
)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
(
use_biases_t,
heads_t,
scale_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t,
) = ctx.saved_tensors
head_dim = inputs.size(2) // heads_t[0]
# Slice out q,k,v from one big Input Linear output (should only impact meta data, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
# input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1) * heads_t[0], 3, head_dim)
queries = input_lin_results[:, :, 0, :]
keys = input_lin_results[:, :, 1, :]
values = input_lin_results[:, :, 2, :]
# Slice out q,k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!)
# The gradients are identical in size to the Input Linear outputs.
# The tensor is declared before hand to properly slice out query, key, and value grads.
input_lin_results_grads = torch.empty_like(input_lin_results)
queries_grads = input_lin_results_grads[:, :, 0, :]
keys_grads = input_lin_results_grads[:, :, 1, :]
values_grads = input_lin_results_grads[:, :, 2, :]
# Output Linear GEMM - DGRAD
# Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
output_lin_grads = torch.mm(
output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights
)
output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))
# Output Linear GEMM - WGRAD
# Input1: (data grads) [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
# Input2: (activations) [seql_q*seqs, embed_dim ]
# Output: [ embed_dim, embed_dim ]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
output_weight_grads = torch.mm(
output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0, 1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)),
)
output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1) * heads_t[0], head_dim).transpose(0, 1)
if use_biases_t[0]:
output_bias_grads = torch.sum(
output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0
)
else:
output_bias_grads = None
# Matmul2 - DGRAD1
# Input1: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
# Output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))
# Matmul2 - DGRAD2
# Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
values_grads = torch.bmm(dropout_results.transpose(1, 2), output_lin_grads, out=values_grads.transpose(0, 1))
# Mask and Scaling for Dropout (not a publicly documented op)
dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0]))
# Softmax Grad (not a publicly documented op)
# Note: newer PyTorch versions take the input dtype as the final argument.
softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results.dtype)
# Matmul1 - DGRAD1
# Input1: (data grads) [seqs*heads, seql_q, seql_k]
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
queries_grads = torch.baddbmm(
queries_grads.transpose(0, 1),
softmax_grads,
keys.transpose(0, 1),
out=queries_grads.transpose(0, 1),
beta=0.0,
alpha=scale_t[0],
)
# Matmul1 - DGRAD2
# Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
keys_grads = torch.baddbmm(
keys_grads.transpose(0, 1),
softmax_grads.transpose(1, 2),
queries.transpose(0, 1),
out=keys_grads.transpose(0, 1),
beta=0.0,
alpha=scale_t[0],
)
# Input Linear GEMM - DGRAD
# input1: (data grads) [seql_q, seqs, 3*embed_dim(3072)]
# input2: (weights) [embed_dim*3 (3072), embed_dim (1024)]
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
input_lin_results_grads = input_lin_results_grads.view(
inputs.size(0) * inputs.size(1), heads_t[0] * 3 * head_dim
)
input_grads = torch.mm(input_lin_results_grads, input_weights)
input_grads = input_grads.view(inputs.size(0), inputs.size(1), inputs.size(2))
# Input Linear GEMM - WGRAD
# input1: (data grads) [seql_q*seqs, 3*embed_dim(3072)]
# input2: (activations) [seql_q*seqs, embed_dim(1024)]
# output: [3*embed_dim, embed_dim]
# GEMM: ( 3*embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (3*embed_dim x embed_dim)
input_weight_grads = torch.mm(
input_lin_results_grads.transpose(0, 1), inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2))
)
if use_biases_t[0]:
input_bias_grads = torch.sum(input_lin_results_grads, 0)
else:
input_bias_grads = None
return (
None,
None,
None,
None,
input_grads,
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads,
None,
None,
)
self_attn_func = SelfAttnFunc.apply
from .fp16_optimizer import FP16_Optimizer
from .fused_adam import FusedAdam
from .fused_lamb import FusedLAMB
import collections
import contextlib
import enum
import importlib
import inspect
import io
import math
import threading
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
from torch.distributed.distributed_c10d import _get_default_group, _get_global_rank
def _round_to_multiple(number, multiple, round_up=True):
"""Assumes arguments are positive integers"""
return (number+multiple-1 if round_up else number) // multiple * multiple
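# Example (illustrative): _round_to_multiple(1000, 128) == 1024,
# while _round_to_multiple(1000, 128, round_up=False) == 896.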
class DistributedFusedAdam(torch.optim.Optimizer):
"""AdamW optimizer with ZeRO algorithm.
Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
This implements the ZeRO-2 algorithm, which distributes the
optimizer state and gradients between parallel processes. In
particular, the parameters are flattened, grouped into fixed-size
buckets, and the optimizer state for each bucket is sharded over
the parallel processes. Options are provided to overlap the
gradient synchronization with the backward pass compute.
Adam was proposed in `Adam: A Method for Stochastic
Optimization`_, AdamW in `Decoupled Weight Decay Regularization`_,
and ZeRO in `ZeRO: Memory Optimizations Toward Training Trillion
Parameter Models`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts
defining parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for
computing running averages of gradient and its square.
(default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty)
(default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad
variant of this algorithm from the paper
`On the Convergence of Adam and Beyond`_ (default: False).
This is not yet supported.
dtype (torch.dtype, optional): datatype for optimizer state
(default: torch.float32)
grad_sync_dtype (torch.dtype, optional): datatype for gradient
synchronization (default: same as dtype)
param_sync_dtype (torch.dtype, optional): datatype for
parameter synchronization (default: same as dtype)
device (torch.device, optional): device for optimizer state
(default: cuda). Currently only supports GPU with one GPU
per process.
process_group (torch.distributed.ProcessGroup, optional):
parallel processes participating in optimizer (default:
default group in torch.distributed). This group is
interpreted as a 2D grid with dimensions
distributed_size x redundant_size.
distributed_process_group (torch.distributed.ProcessGroup,
optional): parallel processes to distribute optimizer
state over (default: same as process_group)
redundant_process_group (torch.distributed.ProcessGroup,
optional): parallel processes to replicate optimizer state
over (default: group only containing calling process)
average_grad_sync (bool, optional): whether to use average
reduction for gradient synchronization rather than sum
(default: True)
overlap_grad_sync(boolean, optional): whether to overlap
gradient synchronization with backward pass compute
(default: True)
bucket_cap_mb (float, optional): bucket size in megabytes
(default: 100)
pipeline_size (int, optional): number of buckets to
synchronize simultaneously (default: 2)
contiguous_grad_buffer (bool, optional): allocate gradient
buckets out of a large persistent buffer (default: False).
This allows individual parameter gradients to be accessed
externally (see grad_buffer_view function). It also
maximizes memory usage and may prevent overlapping
communication and compute.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
.. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
.. _ZeRO\: Memory Optimizations Toward Training Trillion Parameter Models:
https://arxiv.org/abs/1910.02054
"""
class ParameterFragment:
"""Buffer ranges for a parameter fragment
Describes corresponding regions in parameter buffer and
parameter bucket.
"""
def __init__(
self,
param_group_id,
param_id,
bucket_id,
param_range,
bucket_range,
in_local_shard,
shard_range,
shard_bucket_range,
shard_param_range,
):
# Parameter group index
self.param_group_id = param_group_id
# Parameter index within parameter group
self.param_id = param_id
# Bucket index
self.bucket_id = bucket_id
# Range within flattened parameter buffer
self.param_range = param_range
# Range within bucket
self.bucket_range = bucket_range
# Whether fragment is in local shard of bucket
self.in_local_shard = in_local_shard
# Range within local shard
self.shard_range = shard_range
# Range of local fragment shard within bucket
self.shard_bucket_range = shard_bucket_range
# Range of local fragment shard within parameter
self.shard_param_range = shard_param_range
class StateBucket:
def __init__(self, shard_size, dtype, device):
"""Optimizer state for a bucket"""
# Buffer ranges corresponding to parameter fragments
self.fragments = []
# Local shard of parameters
self.params_shard = torch.zeros([shard_size], dtype=dtype, device=device)
# Local shard of first moment estimate
self.exp_avg_shard = torch.zeros([shard_size], dtype=dtype, device=device)
# Local shard of second moment estimate
self.exp_avg_sq_shard = torch.zeros([shard_size], dtype=dtype, device=device)
class GradientStatus(enum.Enum):
"""Status of gradients within a bucket"""
# Gradients are ready to use
READY = enum.auto()
# Bucket is partially filled with unreduced gradients
PARTIALLY_FILLED = enum.auto()
# Bucket is fully filled with unreduced gradients
FULLY_FILLED = enum.auto()
# Asynchronous reduction is in progress
SYNCING = enum.auto()
class GradientBucket:
"""Gradient buffers and state for a bucket"""
def __init__(self):
# Local shard of gradients
self.grads_shard = None
# Local contribution to gradients
self.grads_bucket = None
# Buffer for gradient reduce-scatter
self.sync_grads_shard = None
# Status of gradients
self.status = DistributedFusedAdam.GradientStatus.READY
# Request object for asynchronous communication
self.sync_request = None
def sync_wait(self):
"""Wait for asynchronous communication to finish"""
if self.sync_request is not None:
self.sync_request.wait()
self.sync_request = None
_step_supports_amp_scaling = True
def __init__(self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.,
amsgrad=False,
dtype=torch.float32,
grad_sync_dtype=None,
param_sync_dtype=None,
device='cuda',
process_group=None,
distributed_process_group=None,
redundant_process_group=None,
average_grad_sync=True,
overlap_grad_sync=True,
bucket_cap_mb=100,
pipeline_size=2,
contiguous_grad_buffer=False,
):
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay)
super(DistributedFusedAdam, self).__init__(params, defaults)
# Adam options
if amsgrad:
raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')
# Datatype options
if grad_sync_dtype is None:
grad_sync_dtype = dtype
if param_sync_dtype is None:
param_sync_dtype = dtype
supported_dtypes = [
(torch.float32, torch.float16),
(torch.float32, torch.float32),
]
if (dtype, grad_sync_dtype) not in supported_dtypes:
raise RuntimeError(
'Invalid dtypes for DistributedFusedAdam '
f'(dtype={dtype}, '
f'grad_sync_dtype={grad_sync_dtype}, '
f'param_sync_dtype={param_sync_dtype})')
if device != 'cuda':
raise RuntimeError('DistributedFusedAdam only supports GPU')
self.dtype = dtype
self.grad_sync_dtype = grad_sync_dtype
self.param_sync_dtype = param_sync_dtype
self.device = device
# Process groups
self.process_group = (
_get_default_group()
if process_group is None
else process_group
)
self.distributed_process_group = (
self.process_group
if distributed_process_group is None
else distributed_process_group
)
self.redundant_process_group = redundant_process_group
self.process_group_size = torch.distributed.get_world_size(self.process_group)
self.distributed_rank = torch.distributed.get_rank(self.distributed_process_group)
self.distributed_size = torch.distributed.get_world_size(self.distributed_process_group)
self.redundant_size = (
1
if self.redundant_process_group is None
else torch.distributed.get_world_size(self.redundant_process_group)
)
if self.process_group_size != self.distributed_size * self.redundant_size:
raise RuntimeError(
'Invalid process group configuration '
f'(process group size = {self.process_group_size}, '
f'distributed process group size = {self.distributed_size}, '
f'redundant process group size = {self.redundant_size})'
)
try:
self._process_group_ranks = [
_get_global_rank(self.process_group, local_rank)
for local_rank in range(self.distributed_size)
]
except:
self._process_group_ranks = list(range(self.distributed_size))
# Use average reduction for grad sync
self.average_grad_sync = average_grad_sync
# Copy param grads to bucket as soon as available
self.greedy_grad_copy = True
# Synchronize grad buckets as soon as all grads are available
self.overlap_grad_sync = overlap_grad_sync
# Number of buckets to synchronize at a time
self.pipeline_size = pipeline_size
# Allocate contiguous buffer for gradients
self.contiguous_grad_buffer = contiguous_grad_buffer
# Determine bucket sizes
dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8
self.alignment = 128 // dtype_size
bucket_size = 1024*1024*bucket_cap_mb / dtype_size
shard_size = int(bucket_size / self.distributed_size)
shard_size = _round_to_multiple(shard_size, self.alignment, round_up=False)
shard_size = max(shard_size, self.alignment)
bucket_size = shard_size * self.distributed_size
self.bucket_size = bucket_size
self.shard_size = shard_size
# Load CUDA kernels
global fused_adam_cuda, distributed_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
distributed_adam_cuda = importlib.import_module("distributed_adam_cuda")
# Optimizer state
self.state['buckets'] = []
self.state['step'] = 0
# Objects for gradient synchronization
self._grads_buckets = collections.defaultdict(self.GradientBucket)
self._grads_generated = set()
self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)]
# Divide gradients by factor before optimizer step. Used for
# grad clipping and gradient scaler.
self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device)
# Norm of parameter gradients. Used for gradient clipping and
# gradient scaler.
self._grad_norm = None
# Check if collectives have no_copy option
self._reduce_scatter_no_copy = (
'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args
)
self._all_gather_no_copy = (
'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args
)
self._gather_no_copy = (
'no_copy' in inspect.getfullargspec(torch.distributed.gather).args
)
# Attach hooks for gradient synchronization
self._register_post_backward_hooks()
def _register_post_backward_hooks(self):
"""Attach hooks for gradient synchronization
Optimizer state for parameters are initialized lazily as they
are encountered in the backward pass.
"""
self._num_grads = 0
grad_buffer_size = 0
self._lock = threading.Lock()
self._grad_accs = []
for param_group_id, group in enumerate(self.param_groups):
for param_id, param in enumerate(group['params']):
torch.distributed.broadcast(
param,
src=self._process_group_ranks[0],
group=self.process_group,
)
if param.requires_grad:
self._num_grads += 1
# Callback after gradient is generated
def wrapper(p, p_group_id, p_id):
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]
def reduction_hook(*unused):
with self._lock:
if 'fragments' not in self.state[p]:
self._init_param_state(p, p_group_id, p_id)
if self.greedy_grad_copy:
self._grad_copy(p)
if self.overlap_grad_sync:
self._try_start_bucket_grad_sync(
params=[p],
ignore_last_bucket=True,
)
grad_acc.register_hook(reduction_hook)
self._grad_accs.append(grad_acc)
wrapper(param, param_group_id, param_id)
# Gradient size, with padding for alignment
grad_size = _round_to_multiple(param.numel(), self.alignment)
grad_buffer_size += grad_size
# Allocate contiguous gradient buffer if needed
if self.contiguous_grad_buffer:
grad_buffer_size = _round_to_multiple(
grad_buffer_size,
self.bucket_size,
)
self._grad_buffer = torch.zeros(
[grad_buffer_size],
dtype=self.dtype,
device=self.device,
)
def init_params(self, params=None):
"""Initialize optimizer state for parameters
Arguments:
params (iterable, optional): parameters to initialize
(default: all parameters)
"""
# Default cases
if isinstance(params, torch.Tensor):
params = [params]
elif params is None:
params = []
for group in self.param_groups:
params.extend(group['params'])
# Get indices corresponding to parameters
id_map = dict()
for param_group_id, group in enumerate(self.param_groups):
for param_id, param in enumerate(group['params']):
id_map[param] = (param_group_id, param_id)
# Initialize parameters
for param in params:
if param in id_map and 'fragments' not in self.state[param]:
param_group_id, param_id = id_map[param]
self._init_param_state(param, param_group_id, param_id)
def _init_param_state(
self,
param,
param_group_id,
param_id,
):
"""Initialize optimizer state for a parameter"""
# Make sure there is at least one bucket
if not self.state['buckets']:
self.state['buckets'].append(
self.StateBucket(self.shard_size, self.dtype, self.device)
)
# Split parameter values into fragments
# Note: Each fragment resides within a bucket
param_start = 0
param_size = param.numel()
self.state[param]['fragments'] = []
while param_start < param_size:
# Get current bucket
bucket_id = len(self.state['buckets']) - 1
bucket = self.state['buckets'][bucket_id]
fragment_id = len(bucket.fragments)
# Determine fragment position within bucket
if fragment_id == 0:
bucket_start = 0
else:
_, bucket_start = bucket.fragments[-1].bucket_range
bucket_start = _round_to_multiple(bucket_start, self.alignment)
fragment_size = min(param_size-param_start, self.bucket_size-bucket_start)
param_end = param_start + fragment_size
bucket_end = bucket_start + fragment_size
# Create new bucket if current one is full
if fragment_size <= 0:
self.state['buckets'].append(
self.StateBucket(self.shard_size, self.dtype, self.device)
)
continue
# Fragment position within local shard
shard_id = self.distributed_rank
shard_start = bucket_start - self.shard_size*shard_id
shard_end = bucket_end - self.shard_size*shard_id
shard_start = min(max(shard_start, 0), self.shard_size)
shard_end = min(max(shard_end, 0), self.shard_size)
in_local_shard = shard_start < shard_end
if in_local_shard:
shard_bucket_start = shard_start + self.shard_size*shard_id
shard_bucket_end = shard_bucket_start + shard_end - shard_start
shard_param_start = shard_bucket_start - bucket_start + param_start
shard_param_end = shard_param_start + shard_end - shard_start
else:
shard_bucket_start, shard_bucket_end = None, None
shard_param_start, shard_param_end = None, None
# Record fragment info
fragment = self.ParameterFragment(
param_group_id=param_group_id,
param_id=param_id,
bucket_id=bucket_id,
param_range=(param_start,param_end),
bucket_range=(bucket_start,bucket_end),
in_local_shard=in_local_shard,
shard_range=(shard_start,shard_end),
shard_bucket_range=(shard_bucket_start,shard_bucket_end),
shard_param_range=(shard_param_start,shard_param_end),
)
self.state[param]['fragments'].append(fragment)
bucket.fragments.append(fragment)
param_start = param_end
# Initialize master param buffer
for fragment in self.state[param]['fragments']:
if fragment.in_local_shard:
bucket = self.state['buckets'][fragment.bucket_id]
param_start, param_end = fragment.shard_param_range
shard_start, shard_end = fragment.shard_range
model_param_fragment = param.view(-1)[param_start:param_end]
master_param_fragment = bucket.params_shard[shard_start:shard_end]
master_param_fragment.copy_(model_param_fragment)
def zero_grad(self, set_to_none=True):
"""Clear parameter gradients"""
# Reset bucket buffers
self._grads_buckets.clear()
# Construct views into contiguous grad buffer, if needed
if self.contiguous_grad_buffer:
self._grad_buffer.zero_()
for bucket_id in range(len(self.state['buckets'])):
bucket_start = bucket_id * self.bucket_size
bucket_end = bucket_start + self.bucket_size
bucket = self._grads_buckets[bucket_id]
bucket.grads_bucket = self._grad_buffer[bucket_start:bucket_end]
# Reset param grads
for group in self.param_groups:
for param in group['params']:
if param.grad is None or set_to_none:
param.grad = None
else:
param.grad.zero_()
# Reset other state
self._grads_generated = set()
self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device)
self._grad_norm = None
def _grad_copy(self, param):
"""Copy parameter gradients to buckets"""
# Copy param grad to buckets
for fragment in self.state[param]['fragments']:
# Get fragment position
bucket_id = fragment.bucket_id
bucket = self._grads_buckets[bucket_id]
grad_start, grad_end = fragment.param_range
bucket_start, bucket_end = fragment.bucket_range
# Set reduction status
if bucket.status == self.GradientStatus.SYNCING:
self._finish_bucket_grad_sync()
bucket.status = self.GradientStatus.PARTIALLY_FILLED
# Allocate gradient buffer if needed
if bucket.grads_bucket is None:
if self.contiguous_grad_buffer:
grad_buffer_start = bucket_id * self.bucket_size
grad_buffer_end = grad_buffer_start + self.bucket_size
bucket.grads_bucket = self._grad_buffer[grad_buffer_start:grad_buffer_end]
else:
bucket.grads_bucket = torch.empty(
[self.bucket_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
bucket.grads_bucket.zero_()
# Copy param grad to bucket
if param.grad is not None:
grad_in = param.grad.detach().view(-1)[grad_start:grad_end]
grad_out = bucket.grads_bucket[bucket_start:bucket_end]
if grad_in.data_ptr() != grad_out.data_ptr():
grad_out.add_(grad_in)
# Free param grad buffer
param.grad = None
def grad_buffer_view(self, param):
"""Construct view into grad buffer corresponding to param
Assumes optimizer is using a contiguous grad buffer.
"""
assert self.contiguous_grad_buffer
# Figure out corresponding position in grad buffer
param_fragments = self.state[param]['fragments']
start_bucket_id = param_fragments[0].bucket_id
start_bucket_offset, _ = param_fragments[0].bucket_range
end_bucket_id = param_fragments[-1].bucket_id
_, end_bucket_offset = param_fragments[-1].bucket_range
buffer_start = start_bucket_id * self.bucket_size + start_bucket_offset
buffer_end = end_bucket_id * self.bucket_size + end_bucket_offset
# Construct view into grad buffer
flat_buffer = self._grad_buffer[buffer_start:buffer_end]
return flat_buffer.detach().view(param.size())
def _force_bucket_grad_sync(self):
"""Ensure that all gradient buckets are synchronized"""
# Synchronize all unsynchronized buckets
self._finish_bucket_grad_sync()
buckets = [
bucket
for bucket_id, bucket in sorted(self._grads_buckets.items())
if bucket.status != self.GradientStatus.READY
]
if buckets:
self._start_bucket_grad_sync(buckets)
self._finish_bucket_grad_sync()
# Fill any unsynchronized gradients with zeros
for bucket_id in range(len(self.state['buckets'])):
bucket = self._grads_buckets[bucket_id]
if bucket.grads_shard is None:
bucket.grads_shard = torch.zeros(
[self.shard_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
# Reset set of generated gradients
self._grads_generated = set()
def _try_start_bucket_grad_sync(
self,
params=[],
ignore_last_bucket=True,
):
"""Launches gradient synchronization if enough buckets are ready
Gradient synchronization is asynchronous. Launches gradient
synchronization if all gradients have been generated or if
there are enough buckets ready to fill pipeline.
Arguments:
params (iterable): parameters that have had their
gradients copied to buckets
ignore_last_bucket (bool): avoid synchronizing last bucket
until all gradients have been generated. This avoids
excessive synchronization when initializing buckets in
the first backward pass.
"""
# Register params that have generated grads
for param in params:
self._grads_generated.add(param)
for fragment in self.state[param]['fragments']:
bucket_id = fragment.bucket_id
bucket_fragments = self.state['buckets'][bucket_id].fragments
is_filled = True
for other_fragment in reversed(bucket_fragments):
param_group_id = other_fragment.param_group_id
param_id = other_fragment.param_id
other_param = self.param_groups[param_group_id]['params'][param_id]
if other_param not in self._grads_generated:
is_filled = False
break
if is_filled:
bucket = self._grads_buckets[bucket_id]
bucket.status = self.GradientStatus.FULLY_FILLED
# Launch reductions if enough buckets are ready
if len(self._grads_generated) == self._num_grads:
self._force_bucket_grad_sync()
else:
filled_buckets = []
for bucket_id, bucket in sorted(self._grads_buckets.items()):
if ignore_last_bucket and bucket_id == len(self.state['buckets'])-1:
continue
if bucket.status == self.GradientStatus.FULLY_FILLED:
filled_buckets.append(bucket)
pipeline_size = _round_to_multiple(
len(filled_buckets),
self.pipeline_size,
round_up=False,
)
if pipeline_size > 0:
self._start_bucket_grad_sync(filled_buckets[:pipeline_size])
def _start_bucket_grad_sync(self, buckets):
"""Synchronize gradient buckets
Gradient synchronization is asynchronous. Involves
reduce-scatter over distributed process group and allreduce
over redundant process group.
"""
# Call recursively if more buckets than streams
while len(buckets) > self.pipeline_size:
self._start_bucket_grad_sync(buckets[:self.pipeline_size])
buckets = buckets[self.pipeline_size:]
self._finish_bucket_grad_sync()
# Reduction operation
if self.average_grad_sync:
reduce_op = torch.distributed.ReduceOp.AVG
else:
reduce_op = torch.distributed.ReduceOp.SUM
# Reduce gradients
main_stream = torch.cuda.current_stream()
for stream in self._pipeline_streams:
stream.wait_stream(main_stream)
for i, bucket in enumerate(buckets):
bucket.status = self.GradientStatus.SYNCING
stream = self._pipeline_streams[i % self.pipeline_size]
with torch.cuda.stream(stream):
# Reduce-scatter over distributed process group
bucket.sync_wait()
if self.distributed_size == 1:
bucket.sync_grads_shard = bucket.grads_bucket
else:
with torch.cuda.stream(main_stream):
bucket.sync_grads_shard = torch.zeros(
[self.shard_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
grads_bucket_shards = [
bucket.grads_bucket[i*self.shard_size:(i+1)*self.shard_size]
for i in range(self.distributed_size)
]
if self._reduce_scatter_no_copy:
no_copy_kwarg = { 'no_copy': True }
else:
no_copy_kwarg = {}
bucket.sync_request = (
torch.distributed.reduce_scatter(
bucket.sync_grads_shard,
grads_bucket_shards,
op=reduce_op,
group=self.distributed_process_group,
async_op=True,
**no_copy_kwarg,
)
)
# All-reduce over redundant process group
# Note: Assuming reduce-scatters are finished in the
# order they are submitted, all-reduces should be
# submitted in a consistent order. There could be race
# conditions if wait doesn't finish in order.
if self.redundant_size > 1:
bucket.sync_wait()
bucket.sync_request = (
torch.distributed.all_reduce(
bucket.sync_grads_shard,
op=reduce_op,
group=self.redundant_process_group,
async_op=True,
)
)
def _finish_bucket_grad_sync(self):
"""Wait for any gradient synchronizations that are in progress"""
for bucket_id, bucket in sorted(self._grads_buckets.items()):
if bucket.status == self.GradientStatus.SYNCING:
# Finish asynchronous communication
bucket.sync_wait()
# Accumulate gradient in local shard
if bucket.grads_shard is None:
bucket.grads_shard = bucket.sync_grads_shard
else:
bucket.grads_shard.add_(bucket.sync_grads_shard)
bucket.grads_bucket = None
bucket.sync_grads_shard = None
# Reset status
bucket.status = self.GradientStatus.READY
# Cached gradient norm has been invalidated
self._grad_norm = None
@contextlib.contextmanager
def no_sync(self, greedy_grad_copy=False):
"""Disable overlapped gradient synchronization
Context manager that is similar to
torch.nn.parallel.DistributedDataParallel.no_sync. The
gradients can be synchronized by calling grad_sync or step. If
overlapped gradient synchronization is enabled, gradients can
also be synchronized by leaving the context and performing a
backward pass.
Arguments:
greedy_grad_copy (bool, optional): copy parameter
gradients to buckets as soon as they are generated
(default: False)
"""
old_greedy_grad_copy = self.greedy_grad_copy
old_overlap_grad_sync = self.overlap_grad_sync
self.greedy_grad_copy = greedy_grad_copy
self.overlap_grad_sync = False
try:
yield
finally:
self.greedy_grad_copy = old_greedy_grad_copy
self.overlap_grad_sync = old_overlap_grad_sync
def grad_sync(self):
"""Ensure that all gradients are synchronized"""
for bucket in self.state['buckets']:
for fragment in bucket.fragments:
param_group_id = fragment.param_group_id
param_id = fragment.param_id
param = self.param_groups[param_group_id]['params'][param_id]
if param.grad is not None:
self._grad_copy(param)
self._try_start_bucket_grad_sync(
params=[param],
ignore_last_bucket=False,
)
self._force_bucket_grad_sync()
def _local_grad_norm(self, parameters=[], norm_type=2.0):
"""Local contribution to parameter gradient norm
Returns square of 2-norm. Other norms are not yet supported.
If no parameters are provided, the norm is computed for all
parameters in optimizer. Provided parameters are assumed to be
in optimizer.
"""
norm_type = float(norm_type)
assert norm_type == 2.0
# Make sure that gradients have been reduced
self.grad_sync()
if not parameters or len(parameters) == self._num_grads:
# Compute norm of all local gradients
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
grad_norm_sq = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[[bucket.grads_shard for bucket in self._grads_buckets.values()]],
False,
)[0] ** 2
else:
# Compute norm of selected local gradients
grads = []
for param in parameters:
for fragment in self.state[param]['fragments']:
if fragment.in_local_shard:
bucket = self._grads_buckets[fragment.bucket_id]
shard_start, shard_end = fragment.shard_range
grads.append(bucket.grads_shard[shard_start:shard_end])
if grads:
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
grad_norm_sq = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads],
False,
)[0] ** 2
else:
grad_norm_sq = torch.zeros([1], dtype=torch.float32, device=self.device)
return grad_norm_sq.detach().view([])
def grad_norm(self, parameters=[], norm_type=2.0, force=False):
"""Gradient norm of parameters in optimizer
The norm is computed over all gradients together, as if they
were concatenated into a single vector. All provided
parameters must be managed by optimizer.
The computed value is cached to avoid redundant communication.
Arguments:
parameters (iterable, optional): an iterable of parameters
in optimizer (default: all parameters in optimizer).
norm_type (float or int, optional): type of the used
p-norm (default: 2). Only 2-norm is currently
supported.
force (bool, optional): ignore cached value and force norm
computation (default: False).
"""
if force or self._grad_norm is None:
norm_type = float(norm_type)
assert norm_type == 2.0
grad_norm_sq = self._local_grad_norm(
parameters=parameters,
norm_type=norm_type,
)
torch.distributed.all_reduce(
grad_norm_sq,
op=torch.distributed.ReduceOp.SUM,
group=self.distributed_process_group,
)
self._grad_norm = grad_norm_sq.sqrt()
return self._grad_norm.detach()
def clip_grad_norm(self, max_norm, parameters=[], norm_type=2.0):
"""Clips gradient norm of parameters in optimizer
The norm is computed over all gradients together, as if they
were concatenated into a single vector. The scaling is
deferred until the optimizer step, which should be called
immediately after this function.
The computed grad norm is cached to avoid redundant
communication.
Arguments:
max_norm (float or int): max norm of the gradients
parameters (iterable, optional): an iterable of parameters
in optimizer (default: all parameters in optimizer).
norm_type (float or int, optional): type of the used
p-norm (default: 2)
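        A minimal usage sketch (illustrative only; assumes ``optimizer`` is this
        optimizer instance and a backward pass has already produced gradients)::
            total_norm = optimizer.clip_grad_norm(max_norm=1.0)
            optimizer.step()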
"""
assert max_norm > 0
total_norm = self.grad_norm(parameters=parameters, norm_type=norm_type)
inv_clip_coef = (total_norm + 1e-6) / max_norm
self._inv_grad_scale = torch.clamp(inv_clip_coef, min=1.0).view(1)
return total_norm
def step(self, closure=None, *, grad_scaler=None):
"""Apply Adam optimizer step
Arguments:
closure (callable, optional): closure to recompute loss
(default: None)
grad_scaler (torch.cuda.amp.GradScaler, optional):
gradient scaler (default: None)
"""
# Apply closure
loss = None
if closure is not None:
loss = closure()
# Make sure that gradients have been reduced
self.grad_sync()
# Apply gradient scaler if provided
        # Note: We compute the gradient norm to check for non-finite
        # values. This is more conservative and compute-intensive than
        # checking directly, but it avoids extra communication if the
        # gradient norm has already been computed, e.g. for gradient
        # clipping.
if grad_scaler is not None:
grad_norm = self.grad_norm()
found_inf = torch.logical_not(torch.isfinite(grad_norm))
scaler_state = grad_scaler._per_optimizer_states[id(self)]
scaler_state['found_inf_per_device'] = {found_inf.device: found_inf.float()}
if found_inf.item():
return
else:
assert grad_scaler._scale is not None
self._inv_grad_scale *= grad_scaler._scale
inv_grad_scale = self._inv_grad_scale.item()
# Construct workspace buffers
params_bucket_buffers = [
torch.empty(
[self.bucket_size],
dtype=self.param_sync_dtype,
device=self.device,
)
for _ in range(self.pipeline_size)
]
if self.grad_sync_dtype == self.param_sync_dtype:
shard_start = self.distributed_rank * self.shard_size
shard_end = shard_start + self.shard_size
params_copy_buffers = [
params_bucket[shard_start:shard_end]
for params_bucket in params_bucket_buffers
]
else:
params_copy_buffers = [
torch.empty(
[self.shard_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
for _ in range(self.pipeline_size)
]
# Apply optimizer step to each bucket and synchronize params
self.state['step'] += 1
main_stream = torch.cuda.current_stream()
for stream in self._pipeline_streams:
stream.wait_stream(main_stream)
for bucket_id in range(len(self.state['buckets'])):
stream_id = bucket_id % self.pipeline_size
# Bucket buffers
fragments = self.state['buckets'][bucket_id].fragments
shard_start = self.distributed_rank * self.shard_size
shard_end = shard_start + self.shard_size
params_bucket = params_bucket_buffers[stream_id]
params_bucket_shard = params_bucket[shard_start:shard_end]
params_shard = self.state['buckets'][bucket_id].params_shard
params_copy = params_copy_buffers[stream_id]
exp_avg = self.state['buckets'][bucket_id].exp_avg_shard
exp_avg_sq = self.state['buckets'][bucket_id].exp_avg_sq_shard
grads = self._grads_buckets[bucket_id].grads_shard
# Perform compute on parallel stream
stream = self._pipeline_streams[stream_id]
with torch.cuda.stream(stream):
# Find param fragments in local shard
buffers = collections.defaultdict(list) # p, m, v, g, p_copy
for fragment in fragments:
if fragment.in_local_shard:
param_group_id = fragment.param_group_id
shard_start, shard_end = fragment.shard_range
buffers[param_group_id].append([
params_shard[shard_start:shard_end],
exp_avg[shard_start:shard_end],
exp_avg_sq[shard_start:shard_end],
grads[shard_start:shard_end],
params_copy[shard_start:shard_end],
])
# Fuse param fragments if possible
if len(buffers) == 1:
group_id = list(buffers.keys())[0]
buffers[group_id] = [(
params_shard,
exp_avg,
exp_avg_sq,
grads,
params_copy,
)]
# Apply optimizer step to each param group
for group_id, group_buffers in buffers.items():
# Get param group configs
group = self.param_groups[group_id]
beta1, beta2 = group['betas']
bias_correction = 1 if group['bias_correction'] else 0
eps = group['eps']
weight_decay = group['weight_decay']
# Copy param group configs to GPU
num_fragments = len(group_buffers)
beta1 = torch.full([num_fragments], beta1, dtype=self.dtype, device='cuda')
beta2 = torch.full([num_fragments], beta2, dtype=self.dtype, device='cuda')
bias_correction = torch.full([num_fragments], bias_correction, dtype=torch.int32, device='cuda')
eps = torch.full([num_fragments], eps, dtype=self.dtype, device='cuda')
weight_decay = torch.full([num_fragments], weight_decay, dtype=self.dtype, device='cuda')
# Apply Adam step
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
multi_tensor_applier(
distributed_adam_cuda.multi_tensor_fused_adam,
dummy_overflow_buf,
list(zip(*group_buffers)),
beta1,
beta2,
bias_correction,
eps,
weight_decay,
group['lr'],
inv_grad_scale,
self.state['step'],
1, # Set to 0 to apply eps inside sqrt
)
# Cast parameter dtype if needed
if params_copy.data_ptr() != params_bucket_shard.data_ptr():
params_bucket_shard.copy_(params_copy)
# Allgather updated parameters
if self.distributed_size > 1:
all_params_bucket_shards = [
params_bucket[i*self.shard_size:(i+1)*self.shard_size]
for i in range(self.distributed_size)
]
if self._all_gather_no_copy:
no_copy_kwarg = { 'no_copy': True }
else:
no_copy_kwarg = {}
torch.distributed.all_gather(
all_params_bucket_shards,
params_bucket_shard,
group=self.distributed_process_group,
**no_copy_kwarg,
)
# Copy values to param buffers
buffers = collections.defaultdict(list) # param_in, param_out
for fragment in fragments:
param_group_id = fragment.param_group_id
param_id = fragment.param_id
param = self.param_groups[param_group_id]['params'][param_id]
bucket_start, bucket_end = fragment.bucket_range
param_start, param_end = fragment.param_range
param_in = params_bucket[bucket_start:bucket_end]
param_out = param.detach().view(-1)[param_start:param_end]
if param_in.dtype == param_out.dtype:
# Just copy bytes if buffers have same type
param_in = param_in.view(torch.uint8)
param_out = param_out.view(torch.uint8)
buffers[(param.is_cuda, param.dtype)].append(
(param_in, param_out)
)
for (is_cuda, dtype), dtype_buffers in buffers.items():
fused_kernel_dtypes = (
self.param_sync_dtype,
torch.float32,
torch.float16,
torch.uint8,
)
if is_cuda and dtype in fused_kernel_dtypes:
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
dummy_overflow_buf,
list(zip(*dtype_buffers)),
)
else:
for param_in, param_out in dtype_buffers:
param_out.copy_(param_in)
# Synchronize pipeline streams
for stream in self._pipeline_streams:
main_stream.wait_stream(stream)
return loss
def state_dict(self, gather_on_root=True):
"""Get dictionary containing optimizer state
Default behavior is to perform communication so that the
entire optimizer state is returned on the root rank in the
process group. In this case, all ranks in the process group
must enter this function and no value is returned on non-root
ranks.
Arguments:
gather_on_root (bool, optional): Gather state from all
ranks on the root rank (default: True)
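        A checkpointing sketch (illustrative only; assumes the default process
        group and that ``optimizer`` is this optimizer instance)::
            state = optimizer.state_dict()
            if torch.distributed.get_rank() == 0:
                torch.save({'optimizer': state}, 'checkpoint.pt')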
"""
state_dict = super().state_dict()
if not gather_on_root:
return state_dict
# Export local state to byte string
state_bytes = io.BytesIO()
torch.save(state_dict, state_bytes)
state_bytes.seek(0)
state_bytes_view = state_bytes.getbuffer()
# Get data sizes on all ranks
local_state_size = len(state_bytes_view)
state_sizes = [None] * self.distributed_size
torch.distributed.all_gather_object(
state_sizes,
local_state_size,
group=self.process_group,
)
max_state_size = max(state_sizes)
# Construct workspace buffers
chunk_size = self.shard_size * torch.finfo(self.grad_sync_dtype).bits // 8
if self.distributed_rank == 0:
gathered_state_bytes = [state_bytes.getvalue()]
gathered_state_bytes.extend(bytearray(size) for size in state_sizes[1:])
gathered_chunks_buffers = [
torch.empty(
[chunk_size * self.distributed_size],
dtype=torch.uint8,
device=self.device,
)
for _ in range(self.pipeline_size)
]
else:
chunk_buffers = [
torch.empty(
[chunk_size],
dtype=torch.uint8,
device=self.device,
)
for _ in range(self.pipeline_size)
]
# Split data into chunks and gather on root rank
# Note: Assuming we are using the NCCL backend, communication
# must happen on the GPU. We split the data into fixed-size
# chunks so that the GPU memory usage is limited to
# (chunk_size * distributed_size) bytes.
# TODO: Avoid chunking with direct communication between CPUs
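        # Illustrative sizing (example numbers, not defaults): with a 100 MB
        # local state, a 4 MB chunk_size and distributed_size=8, the loop below
        # performs ceil(100 / 4) = 25 gather rounds, and the root rank stages at
        # most 4 MB * 8 = 32 MB of GPU memory per pipeline stream.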
main_stream = torch.cuda.current_stream()
for stream in self._pipeline_streams:
stream.wait_stream(main_stream)
for stream_id, offset in enumerate(range(0, max_state_size, chunk_size)):
stream_id %= self.pipeline_size
# Buffers for chunk
if self.distributed_rank == 0:
gathered_chunks = [
gathered_chunks_buffers[stream_id][i*chunk_size:(i+1)*chunk_size]
for i in range(self.distributed_size)
]
else:
chunk = chunk_buffers[stream_id]
# Perform communication on parallel stream
stream = self._pipeline_streams[stream_id]
with torch.cuda.stream(stream):
# Copy to GPU
if self.distributed_rank != 0 and offset < local_state_size:
local_chunk_size = min(chunk_size, local_state_size-offset)
chunk[:local_chunk_size].copy_(
torch.frombuffer(
state_bytes_view,
dtype=torch.uint8,
count=local_chunk_size,
offset=offset,
),
non_blocking=True,
)
# Gather on root
if self.distributed_rank == 0:
if self._gather_no_copy:
no_copy_kwarg = { 'no_copy': True }
else:
no_copy_kwarg = {}
torch.distributed.gather(
gathered_chunks[0],
gathered_chunks,
dst=self._process_group_ranks[0],
group=self.process_group,
**no_copy_kwarg,
)
else:
torch.distributed.gather(
chunk,
dst=self._process_group_ranks[0],
group=self.process_group,
)
# Copy back to CPU
if self.distributed_rank == 0:
for rank in range(1, self.distributed_size):
if offset < state_sizes[rank]:
rank_chunk_size = min(chunk_size, state_sizes[rank]-offset)
torch.frombuffer(
gathered_state_bytes[rank],
dtype=torch.uint8,
count=rank_chunk_size,
offset=offset,
).copy_(
gathered_chunks[rank][:rank_chunk_size],
non_blocking=True,
)
# Synchronize GPU
for stream in self._pipeline_streams:
main_stream.wait_stream(stream)
main_stream.synchronize()
# Return gathered state data on root rank
if self.distributed_rank == 0:
return {'gathered_states': gathered_state_bytes}
else:
return None
def load_state_dict(self, state_dict):
"""Load optimizer state"""
# State dict contains state for all ranks
if 'gathered_states' in state_dict:
# Deallocate distributed optimizer state to reduce GPU
# memory usage
if 'buckets' in self.state:
del self.state['buckets']
# Get state for current rank and parse byte string
state_bytes = state_dict['gathered_states'][self.distributed_rank]
state_bytes = io.BytesIO(state_bytes)
state_dict = torch.load(state_bytes)
return super().load_state_dict(state_dict)
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
import torch.distributed.distributed_c10d as c10d
class DistributedFusedLAMB(torch.optim.Optimizer):
"""Implements LAMB algorithm.
Currently GPU-only. Requires Apex to be installed via
``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.
This version of fused LAMB implements 2 fusions.
* Fusion of the LAMB update's elementwise operations
* A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
    :class:`apex.contrib.optimizers.DistributedFusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::
        opt = apex.contrib.optimizers.DistributedFusedLAMB(model.parameters(), lr = ....)
        ...
        opt.step()
    :class:`apex.contrib.optimizers.DistributedFusedLAMB` may be used with or without Amp. If you wish to use :class:`DistributedFusedLAMB` with Amp,
    you may choose any ``opt_level``::
        opt = apex.contrib.optimizers.DistributedFusedLAMB(model.parameters(), lr = ....)
        model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
...
opt.step()
In general, ``opt_level="O1"`` is recommended.
    LAMB was proposed in `Large Batch Optimization for Deep Learning - Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
NOT SUPPORTED now! (default: False)
        adam_w_mode (boolean, optional): if True, apply decoupled weight decay
            (also known as AdamW) instead of L2 regularization (default: True)
        grad_averaging (bool, optional): whether to apply (1-beta1) to grad when
            calculating running averages of gradient. (default: True)
        set_grad_none (bool, optional): whether to set grad to None when the zero_grad()
            method is called. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm
(default: 1.0)
        use_nvlamb (boolean, optional): apply the adaptive learning rate to
            parameters with 0.0 weight decay (default: False)
        step_supports_amp_scaling (boolean, optional): whether to use customized
            gradient unscaling logic (default: True)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
class AtomicCounter(object):
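        # Thread-safe counter that records the order in which parameter
        # gradients first become available; _lazy_init_stage2 later uses this
        # recorded order to re-order parameters to match the actual
        # backward-pass order.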
def __init__(self):
self.value = 0
self.order = []
import threading
self._lock = threading.Lock()
def add(self, idx):
with self._lock:
self.value += 1
self.order.append(idx)
def __init__(self, params,
lr=1e-3, bias_correction = True, grad_averaging=True,
betas=(0.9, 0.999), eps=1e-8,
weight_decay=0., max_grad_norm=0.,
adam_w_mode=True, use_nvlamb=False,
step_supports_amp_scaling=True, overlap_reductions=True,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
e5m2_allgather=False, verbose=False):
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm)
super(DistributedFusedLAMB, self).__init__(params, defaults)
global fused_adam_cuda, distributed_lamb_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")
self._overflow_buf = torch.cuda.IntTensor([0])
self._has_overflow = False
self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
self.multi_tensor_lamb_update_weights = distributed_lamb_cuda.multi_tensor_lamb_update_weights
import amp_C
self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
self._grad_averaging = grad_averaging
self._adam_w_mode = 1 if adam_w_mode else 0
self._use_nvlamb = use_nvlamb
self._step_supports_amp_scaling = step_supports_amp_scaling
self._is_accumulation_step = False
self._last_step = False
self._overlap_reductions = overlap_reductions
self._global_scale = None
self._num_blocks = dwu_num_blocks
self._num_chunks = dwu_num_chunks
self._e5m2_allgather = e5m2_allgather
self._verbose = verbose
self._L2_grad_norm = None
self._current_process_group = c10d._get_default_group()
self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._world_size = torch.distributed.get_world_size()
self._num_groups = self._world_size // self._group_size
self._rank_in_group = torch.distributed.get_rank() % self._group_size
self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')
self._resume_from_checkpoint = False
self._step = torch.cuda.IntTensor([0])
# Master weight, moment, gradient buffers
self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None
import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg
if self._num_groups > 1:
self._ar_pg = []
for dev_i in range(self._group_size):
ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
for i in range(self._num_ar_pg):
if self._verbose:
print(f"creating new group {i}: {ranks}")
grp = torch.distributed.new_group(ranks=ranks)
if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
if self._verbose:
print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
if self._verbose:
print(f"created new group {i}")
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
#for ar_pg in self._ar_pg:
# torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
rs_ranks = []
for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
self._rs_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_rs_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = l2_grad_norm_pg
#torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
#for rs_pg in self._rs_pg:
# torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
#for ag_pg in self._ag_pg:
# torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
self._l2_grad_norm_st = torch.cuda.Stream()
self._completion_st = torch.cuda.Stream()
self._step.record_stream(self._completion_st)
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
self._one = torch.cuda.IntTensor([1])
self._first_step = True
self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
self._param_order = self.AtomicCounter()
def _lazy_init_stage1(self):
if self._lazy_init_stage1_done: return
p_offset = 0
p_i = 0
self._model_params = []
self._grad_accs = []
self._group_properties = []
for group in self.param_groups:
prev = None
beta1, beta2 = group['betas']
beta3 = 1.0 - beta1 if self._grad_averaging else 1.0
bias_correction = 1 if group['bias_correction'] else 0
eps = group['eps']
weight_decay = group['weight_decay']
for p in group['params']:
torch.distributed.broadcast(p, 0)
if not p.requires_grad:
continue
self._model_params.append(p)
self._group_properties.append((
weight_decay,
bias_correction,
beta1,
beta2,
beta3,
eps
))
p_grads_size = p.numel()
def wrapper(param, param_i):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
if self._first_step:
# first time
self._param_order.add(param_i)
else:
idx = self._param_order.order.index(param_i)
self._do_overlapped_reduction(idx, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
wrapper(p, p_i)
p_offset += p_grads_size
                # Only enforce 128-byte alignment (64 fp16 elements) for
                # non-consecutive parameters. RNNs are one example of
                # consecutive parameters: (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
p_i += 1
self._grads_generated = [False]*len(self._model_params)
self._grads_fp16, self._grads_fp32 = [], []
if self._overlap_reductions:
self._current_block = self._num_blocks
self._net_total_param_size = p_offset
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._block_size = self._total_param_size // self._num_blocks
self._chunk_size = self._block_size // self._num_chunks
self._shard_size = self._chunk_size // self._group_size
#print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
# initialize master weights, moments buffers if not loaded from checkpoint
if self._fp32_p is None:
self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
# FIXME: Rethink fp16 label since it's either uint8 or fp16
self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
def _flat_split(p):
def __blockify(p):
return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
def __shardify(p):
return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
list_of_blocks = __blockify(self._flat_grads)
list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
def _full_packed_split(p):
def __shardify(p):
return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
def __blockify(p):
return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
list_of_mega_shards = __shardify(p)
list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
def _packed_split(p):
def __packed_blockify(p):
packed_block_size = self._num_chunks*self._shard_size
return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
def __packed_chunkify(p):
# in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size
return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
list_of_blocks = __packed_blockify(p)
list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)
self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
self._lazy_init_stage1_done = True
def _lazy_init_stage2(self):
if self._lazy_init_stage2_done: return
self._param_order.order.reverse()
# re-order model_params, grad_accs, group_properties lists
self._model_params = [self._model_params[i] for i in self._param_order.order]
self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
self._group_properties = [self._group_properties[i] for i in self._param_order.order]
# re-collect grads info (size, offset) after ordering
prev = None
p_offset = 0
self._grads_info = []
self._individual_flat_grads = []
for i, p in enumerate(self._model_params):
p_grads_size = p.numel()
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))
# for the first iteration
self._do_overlapped_reduction(i, p)
p_offset += p_grads_size
            # Only enforce 128-byte alignment (64 fp16 elements) for
            # non-consecutive parameters. RNNs are one example of
            # consecutive parameters: (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
#print("self._low_param_i", self._low_param_i)
        # This block does two things:
        # 1) Copy model parameters into the master buffer
        # 2) Create tensor lists for unpacking the new parameter tensor after all-gather
self._packed_flat_to_model_params_fp16 = []
self._packed_flat_to_model_params_fp32 = []
self._model_params_num = len(self._model_params)
self._contrib_tensor_list = []
self._contrib_min_param_i, self._contrib_max_param_i = -1, -1
self._contrib_update_frag_for_norm = []
self._contrib_model_param_for_norm_fp16 = []
self._contrib_model_param_for_norm_fp32 = []
self._contrib_model_param_for_norm_is_fp16 = []
self._model_param_is_contrib = []
self._contrib_group_properties = []
for shard_id in range(self._group_size):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size
flat_shard_end = flat_shard_start + self._shard_size
for param_i, (p, grads_info, group_props) in enumerate(zip(self._model_params, self._grads_info, self._group_properties)):
flat_grad_start = grads_info["param_offset"]
flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
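                        # Intersect this parameter's span in the flat gradient
                        # buffer with the current shard's span; a non-empty
                        # intersection means part of the parameter lives in this
                        # (block, chunk, shard).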
                        clipped_start = max(flat_grad_start, flat_shard_start)
                        clipped_end = min(flat_grad_end, flat_shard_end)
if clipped_start < clipped_end:
grad_offset = clipped_start - flat_grad_start
grad_length = clipped_end - clipped_start
shard_offset = clipped_start - flat_shard_start
model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
if model_param_fragment.dtype == torch.float16:
self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )
else:
self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )
if shard_id == self._rank_in_group:
self._model_param_is_contrib.append(param_i)
# copy model parameters into master buffer
master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_u_fragment = self._fp32_u_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
#print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
if not self._resume_from_checkpoint:
master_param_fragment.copy_(model_param_fragment)
self._contrib_group_properties.append(group_props)
self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_u_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, u, g, p_copy
self._contrib_update_frag_for_norm.append(opti_state_u_fragment)
if p.dtype == torch.float16:
self._contrib_model_param_for_norm_fp16.append(p)
else:
self._contrib_model_param_for_norm_fp32.append(p)
self._contrib_model_param_for_norm_is_fp16.append(True if p.dtype == torch.float16 else False)
if self._contrib_min_param_i < 0: self._contrib_min_param_i = param_i
self._contrib_max_param_i = param_i
self._contrib_model_param_for_norm_num = len(self._contrib_model_param_for_norm_is_fp16)
if len(self._contrib_model_param_for_norm_fp16) == 0: self._contrib_model_param_for_norm_fp16 = None
if len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None
self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
self._offsets = torch.tensor(self._model_param_is_contrib, dtype=torch.int64, device='cuda')
p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list))
self._contrib_compute_update_term_tensor_list = [g, p, m, v, u]
self._contrib_update_weights_tensor_list = [u, p, p_copy]
math_type = self._fp32_u.dtype
decay, bias_correction, beta1, beta2, beta3, epsilon = list(zip(*self._contrib_group_properties))
self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')
self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')
self._contrib_beta3 = torch.tensor(beta3, dtype=math_type, device='cuda')
self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')
self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')
self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')
self._packed_flat_to_model_params_fp16 = list(zip(*self._packed_flat_to_model_params_fp16)) if len(self._packed_flat_to_model_params_fp16) > 0 else None
self._packed_flat_to_model_params_fp32 = list(zip(*self._packed_flat_to_model_params_fp32)) if len(self._packed_flat_to_model_params_fp32) > 0 else None
self._lazy_init_stage2_done = True
self.complete_reductions()
self._first_step = False
def set_is_accumulation_step(self, is_accumulation_step):
self._is_accumulation_step = is_accumulation_step
def set_last_step(self, last_step):
self._last_step = last_step
def _get_flush_block(self):
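        # Returns the [start, end) element range of the next block (walking from
        # the last block backwards) whose gradients have all been generated and
        # can therefore be reduced now; returns an empty list otherwise.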
flush_block = []
if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
num_grads = len(self._grads_generated)
contiguous_idx = num_grads
while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
contiguous_idx -= 1
if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
self._current_block -= 1
start = self._current_block * self._block_size
end = (self._current_block+1) * self._block_size
flush_block = [start, end]
return flush_block
def _pipeline_block_reductions(self, block_id):
self._flatten_grad_mt(1.0/self._world_size)
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
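        # For example, gradient chunk (block b, chunk c) is reduce-scattered so
        # that each rank keeps only its own shard of that chunk in
        # _fp16_g_chunks[b][c], stored contiguously in the packed layout.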
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
# Compute L2 grad norm
if block_id == 0:
with torch.cuda.stream(self._l2_grad_norm_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed
                l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
self._L2_grad_norm = l2_grad_norm_sq.sqrt()
def __compute_contrib_param_norm(self):
if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]
gnorm_fp32 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]
            gnorm = torch.empty(size=[self._contrib_model_param_for_norm_num], dtype=torch.float32, device='cuda')
gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)
gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)
elif self._contrib_model_param_for_norm_fp16 is not None:
gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]
elif self._contrib_model_param_for_norm_fp32 is not None:
gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]
return gnorm
def __compute_contrib_update_norm(self):
l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
l2_norm.scatter_(dim=0, index=self._offsets, src=local_contrib_l2_norm)
torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
l2_norm = torch.sqrt(l2_norm)
return l2_norm
def _pipeline_step(self):
global_scale = self.global_scale
max_grad_norm = self.defaults['max_grad_norm']
global_grad_norm = self.L2_grad_norm
# check global_grad_norm and fill overflow_buf
is_finite = (global_grad_norm + 1 > global_grad_norm).int()
self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1
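        # (x + 1 > x) holds only for finite x: it is False for +inf (inf + 1 == inf)
        # and for NaN (comparisons with NaN are always False), so is_finite doubles
        # as an overflow check on the global gradient norm.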
# increment step counter if no overflow
self._step += is_finite
self._completion_st.wait_stream(torch.cuda.current_stream())
self._completion_st.wait_stream(self._l2_grad_norm_st)
# Call step kernel once per step
# Call all-gather once per step
with torch.cuda.stream(self._completion_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
param_norm = self.__compute_contrib_param_norm()
multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,
self._overflow_buf,
self._contrib_compute_update_term_tensor_list, # g, p, m, v, u
self._contrib_beta1,
self._contrib_beta2,
self._contrib_beta3,
self._contrib_bias_correction,
self._step,
self._contrib_epsilon,
self._adam_w_mode,
self._contrib_weight_decay,
global_scale,
global_grad_norm,
max_grad_norm)
upd_norm = self.__compute_contrib_update_norm()
multi_tensor_applier(self.multi_tensor_lamb_update_weights,
self._overflow_buf,
self._contrib_update_weights_tensor_list, # u, p, p_copy
param_norm,
upd_norm,
self._offsets,
self._lr,
self._contrib_weight_decay,
global_grad_norm,
self._use_nvlamb)
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale):
if len(self._grads_fp16) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads_fp16)),
scale)
self._grads_fp16 = []
if len(self._grads_fp32) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads_fp32)),
scale)
self._grads_fp32 = []
def _do_overlapped_reduction(self, param_i, param):
if not self._is_accumulation_step:
# handle overlapped reductions
if param.dtype == torch.float16:
self._grads_fp16.append( (param.grad, self._individual_flat_grads[param_i]) )
else:
self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )
self._grads_generated[param_i]=True
if not self._first_step and not self._last_step:
if self._overlap_reductions:
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._pipeline_block_reductions(block_id)
flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
"""Set global scale.
"""
self._global_scale = global_scale
@property
def global_scale(self):
return self._global_scale
@property
def L2_grad_norm(self):
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
if self._last_step:
# zero out gradients that have not been completed yet
for param_i, grad_generated in enumerate(self._grads_generated):
if not grad_generated:
grad_info = self._grads_info[param_i]
param_offset = grad_info["param_offset"]
param_size = grad_info["param_grads_size"]
self._flat_grads[param_offset:param_offset+param_size].zero_()
self._grads_generated[param_i] = True
if self._first_step or self._last_step or not self._overlap_reductions:
# nothing done so far, run full pipeline after reductions
for block_id in range(self._num_blocks-1,-1,-1):
self._pipeline_block_reductions(block_id)
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
self._current_block = self._num_blocks
self._grads_generated = [False]*len(self._grads_info)
def step(self, closure=None, grad_scaler=None):
loss = None
if closure is not None:
loss = closure()
self._pipeline_step()
if grad_scaler is not None:
found_inf = self._overflow_buf.float()
optimizer_state = grad_scaler._per_optimizer_states[id(self)]
current_device = torch.device('cuda', torch.cuda.current_device())
optimizer_state["found_inf_per_device"][current_device] = found_inf
self._completion_st.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._completion_st):
# Copy self._new_params to model params
with torch.no_grad():
if self._packed_flat_to_model_params_fp16 is not None:
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params_fp16)
if self._packed_flat_to_model_params_fp32 is not None:
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params_fp32)
torch.cuda.current_stream().wait_stream(self._completion_st)
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
return loss
def state_dict(self):
"""
        Returns a dict containing the current state of this :class:`DistributedFusedLAMB` instance.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
# save step, master weights and first/second moments
state_dict = {}
state_dict['step'] = self._step
state_dict['fp32_p'] = self._fp32_p
state_dict['fp32_m'] = self._fp32_m
state_dict['fp32_v'] = self._fp32_v
return state_dict
def load_state_dict(self, state_dict):
"""
Loads a state_dict created by an earlier call to state_dict().
        If a DistributedFusedLAMB instance was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``optimizer.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = DistributedFusedLAMB(model.parameters(), lr=1e-3)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# restore step, master weights and first/second moments
self._step = state_dict['step']
self._fp32_p = state_dict['fp32_p'].to(device="cuda")
self._fp32_m = state_dict['fp32_m'].to(device="cuda")
self._fp32_v = state_dict['fp32_v'].to(device="cuda")
self._resume_from_checkpoint = True
import torch
from apex.multi_tensor_apply import multi_tensor_applier
class FP16_Optimizer(object):
"""
    :class:`FP16_Optimizer` is a cut-down version of apex.fp16_utils.FP16_Optimizer,
    designed only to wrap apex.contrib.optimizers.FusedAdam and FusedSGD.
    Refer to the apex.fp16_utils documentation for more information.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = apex.contrib.optimizers.FusedSGD(model.parameters())
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
# loss.backward() becomes:
optimizer.backward(loss)
...
Example with dynamic loss scaling::
...
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
# optional arg to control dynamic loss scaling behavior
# dynamic_loss_args={'scale_window' : 500})
# Usually, dynamic_loss_args is not necessary.
"""
def __init__(self,
init_optimizer,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
verbose=True):
print("\nThis fp16_optimizer is designed to only work with apex.contrib.optimizers.*")
print("To update, use updated optimizers with AMP.")
        # The fused optimizer does all the work. We need this layer for two reasons:
        # 1. maintain the same user API as apex.fp16_utils
        # 2. keep common code here in case we need to add new fused optimizers later
        if not torch.cuda.is_available():
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer
self.fp16_groups = [] # model params
self.fp32_groups = [] # master weights
# iterate over param_groups
for param_group in self.optimizer.param_groups:
fp16_group = []
fp32_group = []
for p in param_group['params']:
fp16_group.append(p)
fp32_group.append(p.clone().float().detach())
self.fp16_groups.append(fp16_group)
self.fp32_groups.append(fp32_group)
param_group['params'] = fp32_group
if multi_tensor_applier.available:
import amp_C
self.overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
else:
raise RuntimeError('FP16_Optimizer requires cuda extensions')
        # We may have a way of fusing dynamic loss scaling later; it is not supported for now.
if dynamic_loss_scale:
if dynamic_loss_args is not None:
raise SystemError("Do not support dynamic loss scale args for now.")
self.dynamic_loss_scale = True
self.cur_scale = 2**16
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = 2
self.scale_window = 1000
else:
self.dynamic_loss_scale = False
self.cur_iter = 0
self.cur_scale = static_loss_scale
self.verbose = verbose
def zero_grad(self, set_grads_to_None=True):
"""
Zero FP16 parameter grads.
"""
# FP32 grad should never exist.
# For speed, set model fp16 grad to None by default
for group in self.fp16_groups:
for p in group:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
def step(self, closure=None):
"""
        Performs a single optimization step. Closure is not supported.
"""
fp16_grads = []
norm_groups = []
skip = False
for group in self.fp16_groups:
fp16_grad = []
for i, p in enumerate(group):
fp16_grad.append(p.grad)
fp16_grads.append(fp16_grad)
# nan check
self.overflow_buf.zero_()
for fp16_grad in fp16_grads:
if len(fp16_grad) > 0:
norm, norm_per_tensor = multi_tensor_applier(self.multi_tensor_l2norm,
self.overflow_buf,
[fp16_grad], True)
norm_groups.append(norm)
if self.overflow_buf.item() != 0:
skip = True
if skip:
self._update_scale(skip)
return
# norm is in fact norm*cur_scale
self.optimizer.step(grads=fp16_grads,
output_params=self.fp16_groups,
scale=self.cur_scale,
grad_norms=norm_groups)
self._update_scale(False)
return
def backward(self, loss):
"""
:attr:`backward` performs the following steps:
1. fp32_loss = loss.float()
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
scaled_loss = (loss.float()) * self.cur_scale
scaled_loss.backward()
def _update_scale(self, skip):
if self.dynamic_loss_scale:
if skip:
if self.verbose:
print("\nGrad overflow on iteration", self.cur_iter)
print("Using dynamic loss scale of", self.cur_scale)
self.cur_scale = max(self.cur_scale/self.scale_factor, 1)
self.last_overflow_iter = self.cur_iter
else:
if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
self.cur_scale *= self.scale_factor
else:
if skip:
print("\nGrad overflow on iteration", self.cur_iter)
print("Using static loss scale of", self.cur_scale)
self.cur_iter +=1
return
# Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
def _get_state(self):
return self.optimizer.state
def _set_state(self, value):
self.optimizer.state = value
state = property(_get_state, _set_state)
# Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
# (for example, to adjust the learning rate)
def _get_param_groups(self):
return self.optimizer.param_groups
def _set_param_groups(self, value):
self.optimizer.param_groups = value
param_groups = property(_get_param_groups, _set_param_groups)
def state_dict(self):
"""
Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
of the contained Pytorch optimizer.
Example::
checkpoint = {}
checkpoint['model'] = model.state_dict()
checkpoint['optimizer'] = optimizer.state_dict()
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['cur_scale'] = self.cur_scale
state_dict['cur_iter'] = self.cur_iter
if state_dict['dynamic_loss_scale']:
state_dict['last_overflow_iter'] = self.last_overflow_iter
state_dict['scale_factor'] = self.scale_factor
state_dict['scale_window'] = self.scale_window
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
state_dict['fp32_groups'] = self.fp32_groups
return state_dict
def load_state_dict(self, state_dict):
"""
Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called.
Example::
model = torch.nn.Linear(D_in, D_out).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
...
checkpoint = torch.load("saved.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.cur_scale = state_dict['cur_scale']
self.cur_iter = state_dict['cur_iter']
if state_dict['dynamic_loss_scale']:
self.last_overflow_iter = state_dict['last_overflow_iter']
self.scale_factor = state_dict['scale_factor']
self.scale_window = state_dict['scale_window']
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately.
# We choose option 2.
#
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
for current, saved in zip(self.fp32_groups, state_dict['fp32_groups']):
for _current, _saved in zip(current, saved):
_current.data.copy_(_saved.data)
import types
import torch
import importlib
from apex.multi_tensor_apply import multi_tensor_applier
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
.. _Adam - A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params,
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
amp_scale_adjustment=1.0):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._use_multi_tensor = False
if use_mt:
if not multi_tensor_applier.available:
print("Warning: multi_tensor_applier is unavailable")
else:
self._use_multi_tensor = True
self._overflow_buf = torch.cuda.IntTensor([0])
self._amp_scale_adjustment = amp_scale_adjustment
if amsgrad:
raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(FusedAdam, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
            grads (list of tensors, optional): weight gradients to use for the
                optimizer update. If gradients have type torch.half, parameters
                are expected to be in type torch.float. (default: None)
            output_params (list of tensors, optional): a reduced-precision copy
                of the updated weights, written out in addition to the regular
                updated weights. Must be of the same type as the gradients. (default: None)
scale (float, optional): factor to divide gradient tensor values
by before applying to weights. (default: 1)
"""
loss = None
if closure is not None:
loss = closure()
if hasattr(self, "_amp_stash"):
grads = self._amp_stash.grads
output_params = self._amp_stash.output_params
scale = self._amp_stash.scale*self._amp_scale_adjustment
grad_norms = self._amp_stash.grad_norms
if grads is None:
grads_group = [None]*len(self.param_groups)
# backward compatibility
        # assuming a list/generator of parameters means a single group
elif isinstance(grads, types.GeneratorType):
grads_group = [grads]
elif type(grads[0])!=list:
grads_group = [grads]
else:
grads_group = grads
if output_params is None:
output_params_group = [None]*len(self.param_groups)
elif isinstance(output_params, types.GeneratorType):
output_params_group = [output_params]
elif type(output_params[0])!=list:
output_params_group = [output_params]
else:
output_params_group = output_params
if grad_norms is None:
grad_norms = [None]*len(self.param_groups)
for group, grads_this_group, output_params_this_group, grad_norm in zip(self.param_groups, grads_group, output_params_group, grad_norms):
if grads_this_group is None:
grads_this_group = [None]*len(group['params'])
if output_params_this_group is None:
output_params_this_group = [None]*len(group['params'])
# compute combined scale factor for this group
combined_scale = scale
if group['max_grad_norm'] > 0:
# norm is in fact norm*scale
clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm']
if clip > 1:
combined_scale = clip * scale
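                # Net effect: gradients are divided by
                #   combined_scale = scale * max(1, (||g|| / scale) / max_grad_norm)
                # which both undoes the loss scale and clips the unscaled global
                # grad norm to max_grad_norm (up to the 1e-6 term).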
bias_correction = 1 if group['bias_correction'] else 0
if self._use_multi_tensor:
if output_params:
tensorlists = [[],[],[],[],[]]
else:
tensorlists = [[],[],[],[]]
tensordevice = None
for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
                # Note: p.grad should never be set for correct operation of the
                # mixed-precision optimizer, which sometimes sends None gradients
if p.grad is None and grad is None:
continue
if grad is None:
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
if self._use_multi_tensor:
pl = [p.data, exp_avg, exp_avg_sq, grad]
if output_param is not None:
pl.append(out_p)
for tl, t in zip(tensorlists, pl):
tl.append(t)
if tensordevice is None:
tensordevice = p.device
elif tensordevice != p.device:
                            raise RuntimeError('FusedAdam does not support use_mt with tensors on multiple devices')
else:
with torch.cuda.device(p.device):
fused_adam_cuda.adam(p.data,
out_p,
exp_avg,
exp_avg_sq,
grad,
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
if self._use_multi_tensor:
with torch.cuda.device(tensordevice):
multi_tensor_applier(
fused_adam_cuda.adam_mt,
self._overflow_buf,
tensorlists,
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
return loss
import torch
import importlib
import math
from apex.multi_tensor_apply import multi_tensor_applier
class FusedLAMB(torch.optim.Optimizer):
"""Implements LAMB algorithm.
Currently GPU-only. Requires Apex to be installed via
``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_lamb" ./``.
This version of fused LAMB implements 2 fusions.
* Fusion of the LAMB update's elementwise operations
* A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
:class:`apex.contrib.optimizers.FusedLAMB`'s usage is identical to any ordinary PyTorch optimizer::
opt = apex.contrib.optimizers.FusedLAMB(model.parameters(), lr = ....)
...
opt.step()
:class:`apex.optimizers.FusedLAMB` may be used with or without Amp. If you wish to use :class:`FusedLAMB` with Amp,
you may choose any ``opt_level``::
opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
...
opt.step()
In general, ``opt_level="O1"`` is recommended.
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
NOT SUPPORTED now! (default: False)
adam_w_mode (boolean, optional): Apply L2 regularization or weight decay.
True for decoupled weight decay (also known as AdamW). (default: True)
grad_averaging (bool, optional): whether to apply (1-beta2) to grad when
calculating running averages of gradient. (default: True)
set_grad_none (bool, optional): whether to set grad to None when the zero_grad()
method is called. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm
(default: 1.0)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, bias_correction=True,
betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
amsgrad=False, adam_w_mode=True,
grad_averaging=True, set_grad_none=True,
max_grad_norm=1.0):
if amsgrad:
raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm)
super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
fused_lamb_cuda = importlib.import_module("fused_lamb_cuda")
self.multi_tensor_lamb = fused_lamb_cuda.lamb
else:
raise RuntimeError('apex.contrib.optimizers.FusedLAMB requires cuda extensions')
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none
def zero_grad(self):
if self.set_grad_none:
for group in self.param_groups:
for p in group['params']:
p.grad = None
else:
super(FusedLAMB, self).zero_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
# create separate grad lists for fp32 and fp16 params
g_all_32, g_all_16 = [], []
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
if p.dtype == torch.float32:
g_all_32.append(p.grad.data)
elif p.dtype == torch.float16:
g_all_16.append(p.grad.data)
else:
raise RuntimeError('FusedLAMB only supports fp16 and fp32.')
g_norm_32, g_norm_16 = 0.0, 0.0
# compute grad norm for two lists
if len(g_all_32) > 0:
g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_32], False)[0].item()
if len(g_all_16) > 0:
g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_16], False)[0].item()
# blend two grad norms to get global grad norm
global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16)
max_grad_norm = self.defaults['max_grad_norm']
for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
grad_averaging = 1 if group['grad_averaging'] else 0
# assume the same step across the group for now to simplify things
# per-parameter steps could easily be supported by making 'step' a tensor, or by passing a list into the kernel
if 'step' in group:
group['step'] += 1
else:
group['step'] = 1
# create lists for multi-tensor apply
g_16, p_16, m_16, v_16 = [], [], [], []
g_32, p_32, m_32, v_32 = [], [], [], []
for p in group['params']:
if p.grad is None:
continue
if p.grad.data.is_sparse:
raise RuntimeError('FusedLAMB does not support sparse gradients, please consider SparseAdam instead')
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if p.dtype == torch.float16:
g_16.append(p.grad.data)
p_16.append(p.data)
m_16.append(state['exp_avg'])
v_16.append(state['exp_avg_sq'])
elif p.dtype == torch.float32:
g_32.append(p.grad.data)
p_32.append(p.data)
m_32.append(state['exp_avg'])
v_32.append(state['exp_avg_sq'])
else:
raise RuntimeError('FusedLAMB only supports fp16 and fp32.')
if len(g_16) > 0:
multi_tensor_applier(self.multi_tensor_lamb,
self._dummy_overflow_buf,
[g_16, p_16, m_16, v_16],
group['lr'],
beta1,
beta2,
group['eps'],
group['step'],
bias_correction,
group['weight_decay'],
grad_averaging,
self.adam_w_mode,
global_grad_norm,
max_grad_norm)
if len(g_32) > 0:
multi_tensor_applier(self.multi_tensor_lamb,
self._dummy_overflow_buf,
[g_32, p_32, m_32, v_32],
group['lr'],
beta1,
beta2,
group['eps'],
group['step'],
bias_correction,
group['weight_decay'],
grad_averaging,
self.adam_w_mode,
global_grad_norm,
max_grad_norm)
return loss
import types
import torch
from torch.optim.optimizer import Optimizer, required
from apex.multi_tensor_apply import multi_tensor_applier
class FusedSGD(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum).
This version of fused SGD implements 2 fusions.
* Fusion of the SGD update's elementwise operations
* A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
:class:`apex.contrib.optimizers.FusedSGD` should be used without AMP.
:class:`apex.contrib.optimizers.FusedSGD` only works in the case where all parameters require grad.
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
Example:
model = ...
model.half()
optimizer = apex.contrib.optimizers.FusedSGD(model.parameters())
# wrap with FP16_Optimizer
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
optimizer.zero_grad()
...
optimizer.backward(loss)
optimizer.step()
__ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
.. note::
The implementation of SGD with Momentum/Nesterov subtly differs from
Sutskever et al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
.. math::
v = \rho * v + g \\
p = p - lr * v
where p, g, v and :math:`\rho` denote the parameters, gradient,
velocity, and momentum respectively.
This is in contrast to Sutskever et al. and
other frameworks which employ an update of the form
.. math::
v = \rho * v + lr * g \\
p = p - v
The Nesterov version is analogously modified.
"""
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False,
wd_after_momentum=False,
materialize_master_grads=True):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(FusedSGD, self).__init__(params, defaults)
self.wd_after_momentum = wd_after_momentum
if multi_tensor_applier.available:
import amp_C
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_sgd = amp_C.multi_tensor_sgd
else:
raise RuntimeError('apex.contrib.optimizers.FusedSGD requires cuda extensions')
def __setstate__(self, state):
super(FusedSGD, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)
def get_momentums(self, params):
momentums = []
first_run = True
for p in params:
param_state = self.state[p]
# torch.optim.SGD initializes momentum in the main loop, we have
# to do it here, and track whether or not we've done so, so that
# momentum application can be skipped in the main kernel.
if 'momentum_buffer' not in param_state:
first_run = True
buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
momentums.append(buf)
else:
first_run = False
momentums.append(param_state['momentum_buffer'])
return momentums, first_run
def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
grads (list of tensors, optional): weight gradient to use for the
optimizer update. If gradients have type torch.half, parameters
are expected to be in type torch.float. (default: None)
output_params (list of tensors, optional): A reduced precision copy
of the updated weights written out in addition to the regular
updated weights. Have to be of same type as gradients. (default: None)
scale (float, optional): factor to divide gradient tensor values
by before applying to weights. (default: 1)
"""
if hasattr(self, "_amp_stash"):
raise RuntimeError('apex.contrib.optimizers.FusedSGD should not be used with AMP.')
loss = None
if closure is not None:
loss = closure()
if grads is None:
raise RuntimeError('apex.contrib.optimizers.FusedSGD must be wrapped \
with apex.contrib.optimizers.FP16_Optimizer \
which provides grads.')
# backward compatibility
# assuming a list/generator of parameter means single group
elif isinstance(grads, types.GeneratorType):
grads_group = [grads]
elif type(grads[0])!=list:
grads_group = [grads]
else:
grads_group = grads
if output_params is None:
raise RuntimeError('apex.contrib.optimizers.FusedSGD must be wrapped \
with apex.contrib.optimizers.FP16_Optimizer \
which provides output_params.')
elif isinstance(output_params, types.GeneratorType):
output_params_group = [output_params]
elif type(output_params[0])!=list:
output_params_group = [output_params]
else:
output_params_group = output_params
for group, grads_this_group, output_params_this_group in zip(self.param_groups,
grads_group,
output_params_group):
if grads_this_group is None or output_params_this_group is None:
raise RuntimeError('apex.contrib.optimizers.FusedSGD only works \
when all parameters require grad.')
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
lr = group['lr']
first_runs = [True, True]
# output_params_this_group: original weights (either fp16 or fp32)
# group['params']: master weights (fp32)
# grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
# fp32, fp32, fp32, No
fp32_grads = [g for (p, g) in zip(output_params_this_group, grads_this_group) if p.dtype == torch.float32]
fp32_params = [p2 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float32]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
fp32_set = [fp32_grads, fp32_params, fp32_momentums]
# fp16, fp32, fp32, Yes
fp16_grads = [g for (p, g) in zip(output_params_this_group, grads_this_group) if p.dtype == torch.float16]
fp32_from_fp16_params = [p2 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float16]
fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
fp16_params = [p1 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float16]
fp16_set = [fp16_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_params]
launch_sets = [fp16_set, fp32_set]
for launch_set, first_run in zip(launch_sets, first_runs):
assert len(launch_set[0]) == len(launch_set[1])
assert len(launch_set[0]) == len(launch_set[2])
if len(launch_set[0]) > 0:
multi_tensor_applier(
self.multi_tensor_sgd,
self._dummy_overflow_buf,
launch_set,
weight_decay,
momentum,
dampening,
lr,
nesterov,
first_run,
self.wd_after_momentum,
1.0/scale)
return loss
from .peer_memory import PeerMemoryPool
from .peer_halo_exchanger_1d import PeerHaloExchanger1d
import torch
from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d
import peer_memory_cuda as pm
# How to run:
# torchrun --nproc_per_node <num-GPU> <this-python-prog>
# <num-GPU> must be a power of 2 greater than 1.
# Output of this function is used as ground truth in module tests.
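# Example invocation (the GPU count below is illustrative only; any power of 2
# greater than 1 that matches the available GPUs on the node should work):
#   torchrun --nproc_per_node 2 <this-python-prog>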
def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_split):
if explicit_nhwc:
if H_split:
_, Hp, _, _ = list(y.shape)
H = Hp - 2*half_halo
top_out_halo = y[:,half_halo:2*half_halo,:,:]
top_inp_halo = y[:,:half_halo,:,:]
btm_out_halo = y[:,H:H+half_halo,:,:]
btm_inp_halo = y[:,H+half_halo:H+2*half_halo,:,:]
else:
_, _, Wp, _ = list(y.shape)
W = Wp - 2*half_halo
top_out_halo = y[:,:,half_halo:2*half_halo,:]
top_inp_halo = y[:,:,:half_halo,:]
btm_out_halo = y[:,:,W:W+half_halo,:]
btm_inp_halo = y[:,:,W+half_halo:W+2*half_halo,:]
else:
if H_split:
_, _, Hp, _ = list(y.shape)
H = Hp - 2*half_halo
top_out_halo = y[:,:,half_halo:2*half_halo,:]
top_inp_halo = y[:,:,:half_halo,:]
btm_out_halo = y[:,:,H:H+half_halo,:]
btm_inp_halo = y[:,:,H+half_halo:H+2*half_halo,:]
else:
_, _, _, Wp = list(y.shape)
W = Wp - 2*half_halo
top_out_halo = y[:,:,:,half_halo:2*half_halo]
top_inp_halo = y[:,:,:,:half_halo]
btm_out_halo = y[:,:,:,W:W+half_halo]
btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo]
mf = torch.channels_last if y.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format
top_out_halo = top_out_halo.contiguous()
btm_out_halo = btm_out_halo.contiguous()
top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)]
torch.distributed.all_gather(top_inp_halos, top_out_halo)
btm_inp_halos = [torch.empty_like(btm_out_halo) for _ in range(peer_group_size)]
torch.distributed.all_gather(btm_inp_halos, btm_out_halo)
top_rank = (peer_rank + peer_group_size - 1) % peer_group_size
btm_rank = (peer_rank + 1) % peer_group_size
if peer_rank == 0:
top_inp_halo.zero_()
else:
top_inp_halo.copy_(btm_inp_halos[top_rank].to(memory_format=mf))
if peer_rank == peer_group_size-1:
btm_inp_halo.zero_()
else:
btm_inp_halo.copy_(top_inp_halos[btm_rank].to(memory_format=mf))
def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1):
if memory_format == 1:
# 1 -> explicit nhwc
explicit_nhwc = True
if H_split:
y = torch.randn([1,H+2*half_halo,W,C], dtype=dtype, device='cuda')
ym = y[:,half_halo:H+half_halo,:,:]
else:
y = torch.randn([1,H,W+2*half_halo,C], dtype=dtype, device='cuda')
ym = y[:,:,half_halo:W+half_halo,:]
else:
# 2 -> native nhwc
# 3 -> nchw
explicit_nhwc = False
if H_split:
y = torch.randn([1,C,H+2*half_halo,W], dtype=dtype, device='cuda')
if memory_format == 2:
y = y.to(memory_format=torch.channels_last)
ym = y[:,:,half_halo:H+half_halo,:]
else:
y = torch.randn([1,C,H,W+2*half_halo], dtype=dtype, device='cuda')
if memory_format == 2:
y = y.to(memory_format=torch.channels_last)
ym = y[:,:,:,half_halo:W+half_halo]
y3 = y.clone()
list_y = []
for step in range(num_steps):
halo_ex(y, H_split, explicit_nhwc, numSM)
list_y.append(y.clone())
y.copy_(y3)
halo_ex.peer_pool.reset()
torch.distributed.barrier()
y2 = y3.clone()
list_y2 = []
for step in range(num_steps):
nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split)
list_y2.append(y2.clone())
y2.copy_(y3)
is_equal = [torch.all(torch.eq(yy,yy2)) for yy,yy2 in zip(list_y,list_y2)]
is_equal = torch.tensor(is_equal, dtype=torch.bool)
is_equal = torch.all(is_equal)
if peer_rank == 0:
if memory_format == 1:
memory_format_str = "explicit_nhwc"
elif memory_format == 2:
memory_format_str = "native nhwc"
elif memory_format == 3:
memory_format_str = "nchw"
else:
memory_format_str = "???"
if is_equal:
print("SUCCESS : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s" % (C,H,W,half_halo,str(dtype),memory_format_str,"H-split" if H_split else "W-split"))
else:
print("FAILURE : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s" % (C,H,W,half_halo,str(dtype),memory_format_str,"H-split" if H_split else "W-split"))
# peer memory flag sync relies on there being at least one barrier per step
torch.distributed.barrier()
def H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
Hr = 8*world_size
Hp = ((H + Hr - 1) // Hr) * 8
for i in range(4):
div = int(pow(2,i))
single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 1, True, num_steps)
single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 2, True, num_steps)
single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 3, True, num_steps)
def W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
Wr = 8*world_size
Wp = ((W + Wr - 1) // Wr) * 8
for i in range(4):
div = int(pow(2,i))
single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 1, False, num_steps)
single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 2, False, num_steps)
single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 3, False, num_steps)
def main():
# for this trivial example peer_rank == rank and peer_group_size == world_size
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank)
peer_ranks = [i for i in range(world_size)]
pool = PeerMemoryPool(64*1024, 2*1024*1024, peer_ranks)
num_steps = 100
half_halo = 1
halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, half_halo)
H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex,num_steps)
W_split_tests(1,64,200,336, half_halo,rank,world_size,halo_ex,num_steps)
if __name__ == "__main__":
main()
import torch
from apex.contrib.peer_memory import PeerMemoryPool
import peer_memory_cuda as pm
class PeerHaloExchanger1d:
def __init__(self, ranks, rank_in_group, peer_pool, half_halo):
self.peer_group_size = len(ranks)
self.ranks = ranks
self.peer_rank = rank_in_group
self.low_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
self.high_neighbor = (self.peer_rank + 1) % self.peer_group_size
self.low_zero = True if self.peer_rank == 0 else False
self.high_zero = True if self.peer_rank == self.peer_group_size - 1 else False
self.peer_pool = peer_pool
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
self.signals[self.peer_rank].zero_()
self.half_halo = half_halo
def __call__(self, y, H_split=True, explicit_nhwc=False, numSM=1, diagnostics=False):
channels_last = y.is_contiguous(memory_format=torch.channels_last) and not explicit_nhwc
if H_split:
if explicit_nhwc:
_, Hs, _, _ = list(y.shape)
H = Hs - 2*self.half_halo
low_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
low_inp_halo = y[:,:self.half_halo,:,:]
high_out_halo = y[:,H:H+self.half_halo,:,:]
high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
high_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
else:
_, _, Hs, _ = list(y.shape)
H = Hs - 2*self.half_halo
low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
low_inp_halo = y[:,:,:self.half_halo,:]
high_out_halo = y[:,:,H:H+self.half_halo,:]
high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
high_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
else:
if explicit_nhwc:
_, _, Ws, _ = list(y.shape)
W = Ws - 2*self.half_halo
low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True)
low_inp_halo = y[:,:,:self.half_halo,:]
high_out_halo = y[:,:,W:W+self.half_halo,:]
high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True)
high_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
else:
_, _, _, Ws = list(y.shape)
W = Ws - 2*self.half_halo
low_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True)
low_inp_halo = y[:,:,:,:self.half_halo]
high_out_halo = y[:,:,:,W:W+self.half_halo]
high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True)
high_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
pm.push_pull_halos_1d(
diagnostics, explicit_nhwc, numSM,
self.low_zero, low_out_halo, low_tx[self.peer_rank], high_tx[self.low_neighbor], low_inp_halo,
self.high_zero, high_out_halo, high_tx[self.peer_rank], low_tx[self.high_neighbor], high_inp_halo,
self.signals[self.low_neighbor], self.signals[self.high_neighbor], self.signals[self.peer_rank]
)
import torch
import numpy as np
import peer_memory_cuda as pm
class PeerMemoryPool(object):
def __init__(self, static_size, dynamic_size, peer_ranks=None):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
ngpus = min(torch.cuda.device_count(), world_size)
peer_group_size = ngpus
peer_group = rank // ngpus
peer_rank_base = peer_group * ngpus
peer_rank = rank - peer_rank_base
if peer_ranks is None:
peer_ranks = [i+peer_rank_base for i in range(peer_group_size)]
peer_rank_start = peer_rank_base
peer_rank_end = peer_rank_start + peer_group_size - 1
for pr in peer_ranks:
assert(pr >= peer_rank_start and pr <= peer_rank_end), "%d :: peer_rank %d not on same node (ranks=[%d,%d])" % (rank, pr, peer_rank_start, peer_rank_end)
self.alignment = 256
self.static_size = ((static_size + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_size = ((dynamic_size + self.alignment - 1) // self.alignment) * self.alignment
# allocate giant pool of device memory
self.raw = pm.allocate_raw(self.static_size+self.dynamic_size)
# exchange peer pointers with nccl
raw_ipc = pm.get_raw_ipc_address(self.raw).cuda()
peer_raw_ipcs = [torch.empty_like(raw_ipc) for _ in range(world_size)]
torch.distributed.all_gather(peer_raw_ipcs, raw_ipc)
peer_raw_ipcs = torch.stack(peer_raw_ipcs).cpu()
# extract IPC pointers for ranks on same node
peer_raw = pm.get_raw_peers(peer_raw_ipcs[peer_rank_base:peer_rank_base+ngpus], peer_rank, self.raw)
self.peer_raw = [peer_raw[peer_rank-peer_rank_base] for peer_rank in peer_ranks]
self.static_offset = 0
self.dynamic_offset = 0
self.peer_ranks = peer_ranks
def __del__(self):
pm.free_raw(self.raw)
def reset(self):
self.dynamic_offset = 0
def allocate_peer_tensors(self, shape, dtype, channels_last, dynamic):
nels = np.prod(shape)
if dtype == torch.float16:
elem_size = 2
if dynamic:
start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_offset = start + nels * elem_size
assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
return [pm.blob_view_half(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
else:
start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
self.static_offset = start + nels * elem_size
assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
return [pm.blob_view_half(pr + start, shape, channels_last) for pr in self.peer_raw]
if dtype == torch.float32:
elem_size = 4
if dynamic:
start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_offset = start + nels * elem_size
assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
return [pm.blob_view_float(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
else:
start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
self.static_offset = start + nels * elem_size
assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
return [pm.blob_view_float(pr + start, shape, channels_last) for pr in self.peer_raw]
if dtype == torch.int32:
elem_size = 4
if dynamic:
start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_offset = start + nels * elem_size
assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
return [pm.blob_view_int(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
else:
start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
self.static_offset = start + nels * elem_size
assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
return [pm.blob_view_int(pr + start, shape, channels_last) for pr in self.peer_raw]
else:
assert(False), "dtype %s not supported" % (str(dtype))
# Introduction to ASP
This serves as a quick-start for ASP (Automatic SParsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.
## Importing ASP
```
from apex.contrib.sparsity import ASP
```
## Initializing ASP
Apart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/inference:
```
ASP.prune_trained_model(model, optimizer)
```
In the context of a typical PyTorch training loop, it might look like this:
```
ASP.prune_trained_model(model, optimizer)
x, y = DataLoader(args)
for epoch in range(epochs):
y_pred = model(x)
loss = loss_function(y_pred, y)
loss.backward()
optimizer.step()
torch.save(...)
```
The `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step.
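For reference, `prune_trained_model` is a thin convenience wrapper; based on its implementation in `asp.py` (included later in this document), the call above is roughly equivalent to the following explicit sequence (a sketch, not additional code you need to write):
```
from apex.contrib.sparsity import ASP

ASP.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2,
                           whitelist=[torch.nn.Linear, torch.nn.Conv2d],
                           allow_recompute_mask=False)   # register mask buffers on eligible layers
ASP.init_optimizer_for_pruning(optimizer)                # patch optimizer.step() so masks are applied to grads and weights
ASP.compute_sparse_masks()                               # compute 2:4 masks and zero out pruned weights
```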
## Generate a Sparse Network
The following approach serves as a guiding example on how to generate a pruned model that can use Sparse Tensor Cores in the NVIDIA Ampere Architecture. This approach generates a model for deployment, i.e. inference mode.
```
(1) Given a fully trained (dense) network, prune parameter values in a 2:4 sparse pattern.
(2) Fine-tune the pruned model with optimization method and hyper-parameters (learning-rate, schedule, number of epochs, etc.) exactly as those used to obtain the trained model.
(3) (If required) Quantize the model.
```
In code, below is a sketch on how to use ASP for this approach (steps 1 and 2 above).
```
model = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint)
criterion = ... # compare ground truth with model prediction; use the same criterion as used to generate the dense trained model
optimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model
lr_scheduler = ... # learning rate scheduler; use the same schedule as used to generate the dense trained model
from apex.contrib.sparsity import ASP
ASP.prune_trained_model(model, optimizer) # prune the trained model
x, y = DataLoader(args)
for epoch in range(epochs): # train the pruned model for the same number of epochs as used to generate the dense trained model
y_pred = model(x)
loss = criterion(y_pred, y)
loss.backward()
optimizer.step()
lr_scheduler.step()
torch.save(...) # saves the pruned checkpoint with sparsity masks
```
## Non-Standard Usage
If your goal is to easily prepare a network for accelerated inference, please follow the recipe above. However, ASP can also be used to perform experiments with advanced techniques such as training with sparsity from initialization. For example, in order to recompute the sparse mask between training steps, use the following method:
```
ASP.compute_sparse_masks()
```
A more thorough example can be found in `./test/toy_problem.py`.
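As a rough illustration of such an experiment, the sketch below recomputes the mask periodically during training. It assumes ASP was initialized with `allow_recompute_mask=True` so that previously pruned values can be restored when masks are recomputed; `train_loader`, `loss_function`, and `recompute_every` are placeholder names.
```
ASP.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2,
                           whitelist=[torch.nn.Linear, torch.nn.Conv2d],
                           allow_recompute_mask=True)
ASP.init_optimizer_for_pruning(optimizer)
ASP.compute_sparse_masks()                      # initial sparse masks
for step, (x, y) in enumerate(train_loader):
    loss = loss_function(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if (step + 1) % recompute_every == 0:
        ASP.compute_sparse_masks()              # recompute masks from the current weights
```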
## Advanced Usage: Channel Permutation
We introduce channel permutations as an advanced method to maximize the accuracy of structured sparse networks. By permuting weight matrices along their channel dimension and adjusting the surrounding layers appropriately, we demonstrate accuracy recovery for even small, parameter-efficient networks, without affecting inference run-time.
The final accuracy depends strongly on the quality of the permutations, and we provide default algorithms to search for high-quality permutations. The permutation search process can be accelerated by the Apex CUDA extension `apex.contrib.sparsity.permutation_search_kernels`.
If you want to use the GPU to accelerate the permutation search, we recommend installing Apex with the permutation search CUDA extension via
```
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--permutation_search" ./
```
If you want to disable the permutation search process, pass `allow_permutation=False` to the `init_model_for_pruning` function. For example:
```
ASP.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False, allow_permutation=False)
```
Please note that when using multiple GPUs, the same random seed should be set on all GPUs so that the permutation search produces identical results everywhere. The library already implements a `set_identical_seed` function in `permutation_lib.py`, which is called by ASP; we still suggest that users set an identical random seed themselves in their multi-GPU code, for example:
```
import torch
import numpy
import random
torch.manual_seed(identical_seed)
torch.cuda.manual_seed_all(identical_seed)
numpy.random.seed(identical_seed)
random.seed(identical_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```
## Reference Papers
For more details about sparsity support on NVIDIA Ampere GPUs with Sparse Tensor Cores, please refer to our [white paper](https://arxiv.org/abs/2104.08378).
```
@article{mishra2021accelerating,
title={Accelerating sparse deep neural networks},
author={Mishra, Asit and Latorre, Jorge Albericio and Pool, Jeff and Stosic, Darko and Stosic, Dusan and Venkatesh, Ganesh and Yu, Chong and Micikevicius, Paulius},
journal={arXiv preprint arXiv:2104.08378},
year={2021}
}
```
Details about sparsity with permutation can be found in our [paper](https://proceedings.neurips.cc/paper/2021/hash/6e8404c3b93a9527c8db241a1846599a-Abstract.html) published in *Thirty-fifth Conference on Neural Information Processing Systems* (**NeurIPS 2021**):
```
@article{pool2021channel,
title={Channel Permutations for N: M Sparsity},
author={Pool, Jeff and Yu, Chong},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}
```
from .sparse_masklib import create_mask
from .asp import ASP
import types
import torch
from .sparse_masklib import create_mask
from .permutation_lib import Permutation
torchvision_imported=True
try:
import torchvision
except ImportError:
print("[ASP][Warning] torchvision cannot be imported.")
torchvision_imported=False
import json
import os
import string
import time
def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):
eligible_modules_list = []
for name, mod in model.named_modules():
if isinstance(mod, whitelist_layer_types) and name not in disallowed_layer_names:
if allowed_layer_names is not None and name not in allowed_layer_names:
continue
eligible_modules_list.append((name, mod))
return eligible_modules_list
class ASP:
__model = None
__verbosity = 0
__optimizer = None
__sparse_parameters = []
__calculate_mask = None
__allow_permutation = True
__all_parameters = []
__save_permutation_graph = False
__permutation_output_dir = ''
@classmethod
def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
verbosity=3,
whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d],
allowed_layer_names=None, disallowed_layer_names=[],
allow_recompute_mask=False, custom_layer_dict={},
allow_permutation=True):
"""Call this method to modify your model to take advantage of sparse matrix multiplication.
Note that this call alone only augments the model with additional buffers needed for sparse MMA; it does not enable the use of sparse MMA.
If you are starting with a fresh model:
model = ...
ASP.init_model_for_pruning(model, mask_calculator, ...)
if (training) ASP.init_optimizer_for_pruning(optimizer)
ASP.compute_sparse_masks() # sparsity is off by default; call when you want to enable it.
If you are starting from a checkpoint:
model = ...
ASP.init_model_for_pruning(model, mask_calculator, ...)
torch.load(...)
if (training) ASP.init_optimizer_for_pruning(optimizer)
Arguments:
model The model
mask_calculator Either callable that computes mask given a tensor OR pattern string for sparse mask lib.
verbosity Integer controlling verbosity level.
0 -> Only errors.
1 -> Errors and warnings.
2 -> Errors, warnings and info.
3 -> Errors, warnings, info and debug.
whitelist Module types approved for sparsity.
allowed_layer_names If not None, only layer names that appear in this list are considered for sparsity.
disallowed_layer_names If not [], only layer names that do not appear in this list are considered for sparsity.
allow_recompute_mask If True, stores pruned values so that dense weights can be restored.
Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
custom_layer_dict Dictionary of additional layer parameters to sparsify, e.g. {CustomLinear: ['weight']}
allow_permutation If True, allow input-channel permutation to reduce the accuracy impact of weight pruning.
[Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe.
"""
assert (cls.__model is None), "ASP has been initialized already."
cls.__model = model
cls.__verbosity = verbosity
cls.__allow_permutation = allow_permutation
if isinstance(mask_calculator, str):
def create_mask_from_pattern(param):
return create_mask(param, mask_calculator).bool()
cls.__calculate_mask = create_mask_from_pattern
else:
cls.__calculate_mask = mask_calculator #user defined function
# function to extract variables that will be sparsified.
# idea is that you will add one of these functions for each module type that can be sparsified.
if torchvision_imported:
print("[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.")
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torchvision.ops.misc.Conv2d: ['weight']}
else:
sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']}
if custom_layer_dict: # Update default list to include user supplied custom (layer type : parameter tensor), make sure this tensor type is something ASP knows how to prune
sparse_parameter_list.update(custom_layer_dict)
whitelist += list(custom_layer_dict.keys())
for module_type in whitelist:
assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % str(module_type)
if allow_permutation: # find all named modules, extract parameters and decorate, used for offline permutation in K dim
for module_name, module in model.named_modules():
module_type_str = str(type(module)).split("\'")[1]
if module_type_str == 'torch.nn.modules.container.Sequential' or module_type_str.startswith('torchvision.models'):
# filter out the 'torch.nn.modules.container.Sequential' type and the whole model, like 'torchvision.models.vgg.VGG'
continue
for p_name, p in module.named_parameters():
cls.__all_parameters.append((module_name, module, p_name, p))
if module_type_str == 'torch.nn.modules.batchnorm.BatchNorm2d':
# need to get the running_mean and running_var from model.state_dict(), as they are not the learnable parameters
module_mean_name = module_name + '.running_mean'
module_var_name = module_name + '.running_var'
for param_key in model.state_dict():
if module_mean_name == param_key or module_var_name == param_key:
cls.__all_parameters.append((module_name, module, param_key.split(".")[-1], model.state_dict()[param_key]))
# add the __permutation_output_dir field to save the intermediate results for permutation
cls.__permutation_output_dir = '.'
# Set the corresponding params from ASP class to the Permutation class
Permutation.set_permutation_params_from_asp(cls.__model, cls.__sparse_parameters, cls.__all_parameters)
# Set the identical random seed for all GPUs to make sure the same results generated in permutation search
Permutation.set_identical_seed()
# find all sparse modules, extract sparse parameters and decorate
def add_sparse_attributes(module_name, module):
sparse_parameters = sparse_parameter_list[type(module)]
for p_name, p in module.named_parameters():
if p_name in sparse_parameters and p.requires_grad:
# check for NVIDIA's TC compatibility: we check along the horizontal direction
if p.dtype == torch.float32 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #User defines FP32 and APEX internally uses FP16 math
print("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
continue
if p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #For Conv2d dim= K x CRS; we prune along C
print("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
continue
if cls.__verbosity >= 3:
print("[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype)))
mask = torch.ones_like(p).bool()
buffname = p_name.split(".")[-1] # buffer names cannot contain "."
module.register_buffer('__%s_mma_mask' % buffname, mask)
if allow_recompute_mask:
pruned = torch.zeros_like(p).cpu()
module.register_buffer('__%s_mma_pruned_p' % buffname, pruned)
else:
pruned = None
cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned))
else:
if cls.__verbosity >= 3:
print("[ASP] Not sparsifying %s::%s of size=%s and type=%s" % (module_name, p_name, str(p.size()), str(p.dtype)))
for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):
add_sparse_attributes(name, sparse_module)
@classmethod
def already_init_asp_model(cls):
"""Call this method to check whether ASP has been initialized already.
"""
if cls.__model is None:
if cls.__verbosity >= 3:
print("[ASP] ASP has not been initialized.")
return False
else:
if cls.__verbosity >= 3:
print("[ASP] ASP has been initialized already.")
return True
@classmethod
def init_optimizer_for_pruning(cls, optimizer):
"""Call this method to monkey patch optimizer step function so that masks can be applied to
gradients and weights during training.
You must call init_model_for_pruning(...) before calling init_optimizer_for_pruning(...)
"""
assert (cls.__optimizer is None), "ASP has initialized optimizer already."
assert (cls.__calculate_mask is not None), "Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning."
# store pointer to original optimizer step method
cls.__optimizer = optimizer
cls.__optimizer.__step = optimizer.step
def __step(opt_self, *args, **kwargs):
# prune gradients before step method
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
if p.grad is not None: #thx pjudd
p.grad.mul_(mask)
# call original optimizer step method
rval = opt_self.__step(*args, **kwargs)
# prune parameters after step method
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
p.mul_(mask)
return rval
cls.__optimizer.step = types.MethodType(__step, cls.__optimizer)
@classmethod
def compute_sparse_masks(cls):
"""Call this method to enable sparsity.
If init(...) was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None.
"""
with torch.no_grad():
if cls.__allow_permutation:
# Step 1: use the Torch.FX library to build the graph
# Step 2: permutation search with the customized kernel
# Notice: need to use the single GPU to build the Torch.FX graph
# The simplest approach without user intervention:
# A. try to build the graph from the distributed (wrapped) model
# B. if that raises an AttributeError, fall back to the non-distributed model
start_time_build_offline_permutation_graph = time.perf_counter()
try:
offline_permutation_fx_graph, success_in_build_offline_permutation_graph = Permutation.build_offline_permutation_graph(cls.__model.module, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json'))
print("\n[compute_sparse_masks] build offline permutation graph on distributed model.")
except AttributeError:
offline_permutation_fx_graph, success_in_build_offline_permutation_graph = Permutation.build_offline_permutation_graph(cls.__model, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json'))
print("\n[compute_sparse_masks] build offline permutation graph on none-distributed model.")
duration_build_offline_permutation_graph = time.perf_counter() - start_time_build_offline_permutation_graph
print("[compute_sparse_masks] Take {:.4f} seconds to finish build_offline_permutation_graph function.".format(duration_build_offline_permutation_graph))
# Step 3: off-line permutation to avoid the runtime overhead in deployment
if success_in_build_offline_permutation_graph:
start_time_apply_offline_permutation = time.perf_counter()
try:
Permutation.apply_offline_permutation(cls.__model.module, fx_graph=offline_permutation_fx_graph)
print("\n[compute_sparse_masks] apply offline permutation on distributed model.")
except AttributeError:
Permutation.apply_offline_permutation(cls.__model, fx_graph=offline_permutation_fx_graph)
print("\n[compute_sparse_masks] apply offline permutation on none-distributed model.")
duration_apply_offline_permutation = time.perf_counter() - start_time_apply_offline_permutation
print("[compute_sparse_masks] Take {:.4f} seconds to finish apply_offline_permutation function.\n".format(duration_apply_offline_permutation))
else:
print("[compute_sparse_masks] skip applying offline permutation because there is no valid offline_permutation_fx_graph.")
# Finally, permutation search and offline permutation are done; hand the model back to ASP to generate the normal structured sparse masks
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
if mask.sum() < mask.numel(): # when recalculating masks
# restore dense parameter if allow_recompute_mask is enabled
assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False"
p.add_(pruned.cuda())
mask.set_(cls.__calculate_mask(p))
if pruned is not None: # stow away pruned weights to cpu
pruned.set_((p * (~mask)).cpu())
p.mul_(mask) # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights
if cls.__verbosity >= 2:
print("[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s" % (100.0-100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype)))
@classmethod
def restore_pruned_weights(cls):
"""Call this method to disable sparsity and restore all weights.
This will only work if init(...) was called with allow_recompute=True.
"""
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
if mask.sum() < mask.numel():
assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False"
p.add_(pruned.cuda())
mask.fill_(1)
pruned.zero_()
if cls.__verbosity >= 2:
print("[ASP] Disabled sparsity for %s::%s (dense weights restored)" % (module_name, p_name))
@classmethod
def is_sparsity_enabled(cls):
"""Call this method to determine if sparsity is enabled in the model.
The typical use case is right after checkpoint has been loaded.
"""
total,sp100,sp50 = 0,0,0
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
total += 1
mask_sum = mask.sum()
mask_numel = mask.numel()
if mask_sum == mask_numel:
sp100 += 1
elif mask_sum*2 == mask_numel:
sp50 += 1
assert (total == sp100 or total == sp50), "Inconsistent model sparsity"
if total == sp100:
return False
elif total == sp50:
return True
@classmethod
def prune_trained_model(cls, model, optimizer):
# add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)
cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False)
cls.init_optimizer_for_pruning(optimizer)
cls.compute_sparse_masks()
@classmethod
def set_permutation_saving_params(cls, allow_permutation=True, save_permutation_graph=False, permutation_output_dir='.'):
"""This function is used to set the permutation saving related parameters in ASP class and inside of the Permutation class."""
print("\n[ASP][set_permutation_saving_param] Set permutation saving related parameters")
print("\n[set_permutation_saving_param] Set permutation saving related parameters")
cls.__allow_permutation = allow_permutation
print("[set_permutation_saving_param]\t Allow permutation: {}".format(cls.__allow_permutation))
cls.__save_permutation_graph = save_permutation_graph
print("[set_permutation_saving_param]\t Save permutation graphs: {}".format(cls.__save_permutation_graph))
cls.__permutation_output_dir = permutation_output_dir
print("[set_permutation_saving_param]\t Permutation graphs saving dir: {}".format(cls.__permutation_output_dir))
Permutation.set_permutation_saving_params(allow_permutation, save_permutation_graph, permutation_output_dir)