Unverified Commit 96850dfa authored by Jithun Nair, committed by GitHub

Merge pull request #80 from ROCmSoftwarePlatform/IFU-master-2022-07-29

IFU-master-2022-07-29
parents 87fc4125 cc5f83b5
@@ -10,11 +10,6 @@ from .fast_encdec_multihead_attn_func import fast_encdec_attn_func
 from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func
 from apex.normalization.fused_layer_norm import FusedLayerNorm
-if hasattr(torch._C, "_jit_set_profiling_executor"):
-    torch._C._jit_set_profiling_executor(False)
-if hasattr(torch._C, "_jit_set_profiling_mode"):
-    torch._C._jit_set_profiling_mode(False)
 @torch.jit.script
 def jit_dropout_add(x, residual, prob, is_training):
...
@@ -10,11 +10,6 @@ from .fast_self_multihead_attn_func import fast_self_attn_func
 from .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func
 from apex.normalization.fused_layer_norm import FusedLayerNorm
-if hasattr(torch._C, "_jit_set_profiling_executor"):
-    torch._C._jit_set_profiling_executor(False)
-if hasattr(torch._C, "_jit_set_profiling_mode"):
-    torch._C._jit_set_profiling_mode(False)
 @torch.jit.script
 def jit_dropout_add(x, residual, prob, is_training):
...
import collections
import contextlib
import enum
import importlib
import inspect
import io
import math
import threading

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
from torch.distributed.distributed_c10d import _get_default_group, _get_global_rank


def _round_to_multiple(number, multiple, round_up=True):
    """Assumes arguments are positive integers"""
    return (number+multiple-1 if round_up else number) // multiple * multiple
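A quick sanity check of the helper's rounding semantics (illustrative only, not part of the patch):

    # Round up to the next multiple (default), or down when round_up=False.
    assert _round_to_multiple(1000, 64) == 1024
    assert _round_to_multiple(1000, 64, round_up=False) == 960
    assert _round_to_multiple(1024, 64) == 1024   # exact multiples are unchanged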
class DistributedFusedAdam(torch.optim.Optimizer):
    """AdamW optimizer with ZeRO algorithm.

    Currently GPU-only. Requires Apex to be installed via
    ``python setup.py install --cuda_ext --cpp_ext``.

    This implements the ZeRO-2 algorithm, which distributes the
    optimizer state and gradients between parallel processes. In
    particular, the parameters are flattened, grouped into fixed-size
    buckets, and the optimizer state for each bucket is sharded over
    the parallel processes. Options are provided to overlap the
    gradient synchronization with the backward pass compute.

    Adam was proposed in `Adam: A Method for Stochastic
    Optimization`_, AdamW in `Decoupled Weight Decay Regularization`_,
    and ZeRO in `ZeRO: Memory Optimizations Toward Training Trillion
    Parameter Models`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts
            defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for
            computing running averages of gradient and its square.
            (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty)
            (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad
            variant of this algorithm from the paper
            `On the Convergence of Adam and Beyond`_ (default: False).
            This is not yet supported.
        dtype (torch.dtype, optional): datatype for optimizer state
            (default: torch.float32)
        grad_sync_dtype (torch.dtype, optional): datatype for gradient
            synchronization (default: same as dtype)
        param_sync_dtype (torch.dtype, optional): datatype for
            parameter synchronization (default: same as dtype)
        device (torch.device, optional): device for optimizer state
            (default: cuda). Currently only supports GPU with one GPU
            per process.
        process_group (torch.distributed.ProcessGroup, optional):
            parallel processes participating in optimizer (default:
            default group in torch.distributed). This group is
            interpreted as a 2D grid with dimensions
            distributed_size x redundant_size.
        distributed_process_group (torch.distributed.ProcessGroup,
            optional): parallel processes to distribute optimizer
            state over (default: same as process_group)
        redundant_process_group (torch.distributed.ProcessGroup,
            optional): parallel processes to replicate optimizer state
            over (default: group only containing calling process)
        average_grad_sync (bool, optional): whether to use average
            reduction for gradient synchronization rather than sum
            (default: True)
        overlap_grad_sync (boolean, optional): whether to overlap
            gradient synchronization with backward pass compute
            (default: True)
        bucket_cap_mb (float, optional): bucket size in megabytes
            (default: 100)
        pipeline_size (int, optional): number of buckets to
            synchronize simultaneously (default: 2)
        contiguous_grad_buffer (bool, optional): allocate gradient
            buckets out of a large persistent buffer (default: False).
            This allows individual parameter gradients to be accessed
            externally (see grad_buffer_view function). It also
            maximizes memory usage and may prevent overlapping
            communication and compute.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _ZeRO\: Memory Optimizations Toward Training Trillion Parameter Models:
        https://arxiv.org/abs/1910.02054

    """
    class ParameterFragment:
        """Buffer ranges for a parameter fragment

        Describes corresponding regions in parameter buffer and
        parameter bucket.

        """
        def __init__(
            self,
            param_group_id,
            param_id,
            bucket_id,
            param_range,
            bucket_range,
            in_local_shard,
            shard_range,
            shard_bucket_range,
            shard_param_range,
):
# Parameter group index
self.param_group_id = param_group_id
# Parameter index within parameter group
self.param_id = param_id
# Bucket index
self.bucket_id = bucket_id
# Range within flattened parameter buffer
self.param_range = param_range
# Range within bucket
self.bucket_range = bucket_range
# Whether fragment is in local shard of bucket
self.in_local_shard = in_local_shard
# Range within local shard
self.shard_range = shard_range
# Range of local fragment shard within bucket
self.shard_bucket_range = shard_bucket_range
# Range of local fragment shard within parameter
self.shard_param_range = shard_param_range
class StateBucket:
def __init__(self, shard_size, dtype, device):
"""Optimizer state for a bucket"""
# Buffer ranges corresponding to parameter fragments
self.fragments = []
# Local shard of parameters
self.params_shard = torch.zeros([shard_size], dtype=dtype, device=device)
# Local shard of first moment estimate
self.exp_avg_shard = torch.zeros([shard_size], dtype=dtype, device=device)
# Local shard of second moment estimate
self.exp_avg_sq_shard = torch.zeros([shard_size], dtype=dtype, device=device)
class GradientStatus(enum.Enum):
"""Status of gradients within a bucket"""
# Gradients are ready to use
READY = enum.auto()
# Bucket is partially filled with unreduced gradients
PARTIALLY_FILLED = enum.auto()
# Bucket is fully filled with unreduced gradients
FULLY_FILLED = enum.auto()
# Asynchronous reduction is in progress
SYNCING = enum.auto()
class GradientBucket:
"""Gradient buffers and state for a bucket"""
def __init__(self):
# Local shard of gradients
self.grads_shard = None
# Local contribution to gradients
self.grads_bucket = None
# Buffer for gradient reduce-scatter
self.sync_grads_shard = None
# Status of gradients
self.status = DistributedFusedAdam.GradientStatus.READY
# Request object for asynchronous communication
self.sync_request = None
def sync_wait(self):
"""Wait for asynchronous communication to finish"""
if self.sync_request is not None:
self.sync_request.wait()
self.sync_request = None
_step_supports_amp_scaling = True
def __init__(self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.,
amsgrad=False,
dtype=torch.float32,
grad_sync_dtype=None,
param_sync_dtype=None,
device='cuda',
process_group=None,
distributed_process_group=None,
redundant_process_group=None,
average_grad_sync=True,
overlap_grad_sync=True,
bucket_cap_mb=100,
pipeline_size=2,
contiguous_grad_buffer=False,
):
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay)
super(DistributedFusedAdam, self).__init__(params, defaults)
# Adam options
if amsgrad:
raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')
# Datatype options
if grad_sync_dtype is None:
grad_sync_dtype = dtype
if param_sync_dtype is None:
param_sync_dtype = dtype
supported_dtypes = [
(torch.float32, torch.float16),
(torch.float32, torch.float32),
]
if (dtype, grad_sync_dtype) not in supported_dtypes:
raise RuntimeError(
'Invalid dtypes for DistributedFusedAdam '
f'(dtype={dtype}, '
f'grad_sync_dtype={grad_sync_dtype}, '
f'param_sync_dtype={param_sync_dtype}))')
if device != 'cuda':
raise RuntimeError('DistributedFusedAdam only supports GPU')
self.dtype = dtype
self.grad_sync_dtype = grad_sync_dtype
self.param_sync_dtype = param_sync_dtype
self.device = device
# Process groups
self.process_group = (
_get_default_group()
if process_group is None
else process_group
)
self.distributed_process_group = (
self.process_group
if distributed_process_group is None
else distributed_process_group
)
self.redundant_process_group = redundant_process_group
self.process_group_size = torch.distributed.get_world_size(self.process_group)
self.distributed_rank = torch.distributed.get_rank(self.distributed_process_group)
self.distributed_size = torch.distributed.get_world_size(self.distributed_process_group)
self.redundant_size = (
1
if self.redundant_process_group is None
else torch.distributed.get_world_size(self.redundant_process_group)
)
if self.process_group_size != self.distributed_size * self.redundant_size:
raise RuntimeError(
'Invalid process group configuration '
f'(process group size = {self.process_group_size}, '
f'distributed process group size = {self.distributed_size}, '
f'redundant process group size = {self.redundant_size})'
)
try:
self._process_group_ranks = [
_get_global_rank(self.process_group, local_rank)
for local_rank in range(self.distributed_size)
]
except:
self._process_group_ranks = list(range(self.distributed_size))
# Use average reduction for grad sync
self.average_grad_sync = average_grad_sync
# Copy param grads to bucket as soon as available
self.greedy_grad_copy = True
# Synchronize grad buckets as soon as all grads are available
self.overlap_grad_sync = overlap_grad_sync
# Number of buckets to synchronize at a time
self.pipeline_size = pipeline_size
# Allocate contiguous buffer for gradients
self.contiguous_grad_buffer = contiguous_grad_buffer
# Determine bucket sizes
dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8
self.alignment = 128 // dtype_size
bucket_size = 1024*1024*bucket_cap_mb / dtype_size
shard_size = int(bucket_size / self.distributed_size)
shard_size = _round_to_multiple(shard_size, self.alignment, round_up=False)
shard_size = max(shard_size, self.alignment)
bucket_size = shard_size * self.distributed_size
self.bucket_size = bucket_size
self.shard_size = shard_size
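        # Worked example (illustrative, not from the patch): with
        # grad_sync_dtype=torch.float16 (2 bytes) and bucket_cap_mb=100,
        # alignment = 128 // 2 = 64 elements and
        # bucket_size = 1024*1024*100 / 2 = 52,428,800 elements.
        # With distributed_size=8, shard_size = 6,553,600, which is already
        # a multiple of 64, so bucket_size stays 8 * 6,553,600 = 52,428,800.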
# Load CUDA kernels
        global fused_adam_cuda, distributed_adam_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")
        distributed_adam_cuda = importlib.import_module("distributed_adam_cuda")

        # Optimizer state
        self.state['buckets'] = []
        self.state['step'] = 0

        # Objects for gradient synchronization
        self._grads_buckets = collections.defaultdict(self.GradientBucket)
        self._grads_generated = set()
        self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)]

        # Divide gradients by factor before optimizer step. Used for
        # grad clipping and gradient scaler.
        self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device)

        # Norm of parameter gradients. Used for gradient clipping and
        # gradient scaler.
        self._grad_norm = None

        # Check if collectives have no_copy option
        self._reduce_scatter_no_copy = (
            'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args
        )
        self._all_gather_no_copy = (
            'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args
        )
        self._gather_no_copy = (
            'no_copy' in inspect.getfullargspec(torch.distributed.gather).args
        )

        # Attach hooks for gradient synchronization
        self._register_post_backward_hooks()

    def _register_post_backward_hooks(self):
        """Attach hooks for gradient synchronization

        Optimizer state for parameters is initialized lazily as the
        parameters are encountered in the backward pass.

        """
        self._num_grads = 0
        grad_buffer_size = 0
        self._lock = threading.Lock()
        self._grad_accs = []
        for param_group_id, group in enumerate(self.param_groups):
for param_id, param in enumerate(group['params']):
torch.distributed.broadcast(
param,
src=self._process_group_ranks[0],
group=self.process_group,
)
if param.requires_grad:
self._num_grads += 1
# Callback after gradient is generated
def wrapper(p, p_group_id, p_id):
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]
def reduction_hook(*unused):
with self._lock:
if 'fragments' not in self.state[p]:
self._init_param_state(p, p_group_id, p_id)
if self.greedy_grad_copy:
self._grad_copy(p)
if self.overlap_grad_sync:
self._try_start_bucket_grad_sync(
params=[p],
ignore_last_bucket=True,
)
grad_acc.register_hook(reduction_hook)
self._grad_accs.append(grad_acc)
wrapper(param, param_group_id, param_id)
# Gradient size, with padding for alignment
grad_size = _round_to_multiple(param.numel(), self.alignment)
grad_buffer_size += grad_size
# Allocate contiguous gradient buffer if needed
if self.contiguous_grad_buffer:
grad_buffer_size = _round_to_multiple(
grad_buffer_size,
self.bucket_size,
)
self._grad_buffer = torch.zeros(
[grad_buffer_size],
dtype=self.dtype,
device=self.device,
)
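When contiguous_grad_buffer=True, the persistent buffer allocated above lets external code alias a parameter's gradient directly through grad_buffer_view (sketch; construction details omitted, and param must be a parameter managed by this optimizer):

    optimizer = DistributedFusedAdam(model.parameters(), contiguous_grad_buffer=True)
    ...
    grad_view = optimizer.grad_buffer_view(param)   # tensor shaped like param, backed by the bucket buffer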
def init_params(self, params=None):
"""Initialize optimizer state for parameters
Arguments:
params (iterable, optional): parameters to initialize
(default: all parameters)
"""
# Default cases
if isinstance(params, torch.Tensor):
params = [params]
elif params is None:
params = []
for group in self.param_groups:
params.extend(group['params'])
# Get indices corresponding to parameters
id_map = dict()
for param_group_id, group in enumerate(self.param_groups):
for param_id, param in enumerate(group['params']):
id_map[param] = (param_group_id, param_id)
# Initialize parameters
for param in params:
if param in id_map and 'fragments' not in self.state[param]:
param_group_id, param_id = id_map[param]
self._init_param_state(param, param_group_id, param_id)
def _init_param_state(
self,
param,
param_group_id,
param_id,
):
"""Initialize optimizer state for a parameter"""
# Make sure there is at least one bucket
if not self.state['buckets']:
self.state['buckets'].append(
self.StateBucket(self.shard_size, self.dtype, self.device)
)
# Split parameter values into fragments
# Note: Each fragment resides within a bucket
param_start = 0
param_size = param.numel()
self.state[param]['fragments'] = []
while param_start < param_size:
# Get current bucket
bucket_id = len(self.state['buckets']) - 1
bucket = self.state['buckets'][bucket_id]
fragment_id = len(bucket.fragments)
# Determine fragment position within bucket
if fragment_id == 0:
bucket_start = 0
else:
_, bucket_start = bucket.fragments[-1].bucket_range
bucket_start = _round_to_multiple(bucket_start, self.alignment)
fragment_size = min(param_size-param_start, self.bucket_size-bucket_start)
param_end = param_start + fragment_size
bucket_end = bucket_start + fragment_size
# Create new bucket if current one is full
if fragment_size <= 0:
self.state['buckets'].append(
self.StateBucket(self.shard_size, self.dtype, self.device)
)
continue
# Fragment position within local shard
shard_id = self.distributed_rank
shard_start = bucket_start - self.shard_size*shard_id
shard_end = bucket_end - self.shard_size*shard_id
shard_start = min(max(shard_start, 0), self.shard_size)
shard_end = min(max(shard_end, 0), self.shard_size)
in_local_shard = shard_start < shard_end
if in_local_shard:
shard_bucket_start = shard_start + self.shard_size*shard_id
shard_bucket_end = shard_bucket_start + shard_end - shard_start
shard_param_start = shard_bucket_start - bucket_start + param_start
shard_param_end = shard_param_start + shard_end - shard_start
else:
shard_bucket_start, shard_bucket_end = None, None
shard_param_start, shard_param_end = None, None
# Record fragment info
fragment = self.ParameterFragment(
param_group_id=param_group_id,
param_id=param_id,
bucket_id=bucket_id,
param_range=(param_start,param_end),
bucket_range=(bucket_start,bucket_end),
in_local_shard=in_local_shard,
shard_range=(shard_start,shard_end),
shard_bucket_range=(shard_bucket_start,shard_bucket_end),
shard_param_range=(shard_param_start,shard_param_end),
)
self.state[param]['fragments'].append(fragment)
bucket.fragments.append(fragment)
param_start = param_end
# Initialize master param buffer
for fragment in self.state[param]['fragments']:
if fragment.in_local_shard:
bucket = self.state['buckets'][fragment.bucket_id]
param_start, param_end = fragment.shard_param_range
shard_start, shard_end = fragment.shard_range
model_param_fragment = param.view(-1)[param_start:param_end]
master_param_fragment = bucket.params_shard[shard_start:shard_end]
master_param_fragment.copy_(model_param_fragment)
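To make the fragment bookkeeping concrete, consider a hypothetical setup (numbers chosen for illustration, not taken from the patch): bucket_size=1024, shard_size=256, distributed_size=4, and a current bucket already filled up to offset 896.

    # A new parameter with 1500 elements is split into three fragments:
    #   fragment 0: bucket k,   bucket_range=(896, 1024), param_range=(0, 128)
    #   fragment 1: bucket k+1, bucket_range=(0, 1024),   param_range=(128, 1152)
    #   fragment 2: bucket k+2, bucket_range=(0, 348),    param_range=(1152, 1500)
    # On rank 3, whose shard covers bucket elements 768..1024, fragment 0 lands
    # in the local shard with shard_range=(128, 256); on ranks 0-2 it does not.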
def zero_grad(self, set_to_none=True):
"""Clear parameter gradients"""
# Reset bucket buffers
self._grads_buckets.clear()
# Construct views into contiguous grad buffer, if needed
if self.contiguous_grad_buffer:
self._grad_buffer.zero_()
for bucket_id in range(len(self.state['buckets'])):
bucket_start = bucket_id * self.bucket_size
bucket_end = bucket_start + self.bucket_size
bucket = self._grads_buckets[bucket_id]
bucket.grads_bucket = self._grad_buffer[bucket_start:bucket_end]
# Reset param grads
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None or set_to_none:
                    param.grad = None
                else:
                    param.grad.zero_()

        # Reset other state
        self._grads_generated = set()
        self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device)
        self._grad_norm = None

    def _grad_copy(self, param):
        """Copy parameter gradients to buckets"""

        # Copy param grad to buckets
        for fragment in self.state[param]['fragments']:

            # Get fragment position
            bucket_id = fragment.bucket_id
            bucket = self._grads_buckets[bucket_id]
            grad_start, grad_end = fragment.param_range
            bucket_start, bucket_end = fragment.bucket_range

            # Set reduction status
            if bucket.status == self.GradientStatus.SYNCING:
                self._finish_bucket_grad_sync()
            bucket.status = self.GradientStatus.PARTIALLY_FILLED

            # Allocate gradient buffer if needed
            if bucket.grads_bucket is None:
                if self.contiguous_grad_buffer:
                    grad_buffer_start = bucket_id * self.bucket_size
                    grad_buffer_end = grad_buffer_start + self.bucket_size
                    bucket.grads_bucket = self._grad_buffer[grad_buffer_start:grad_buffer_end]
                else:
                    bucket.grads_bucket = torch.empty(
                        [self.bucket_size],
                        dtype=self.grad_sync_dtype,
                        device=self.device,
                    )
                bucket.grads_bucket.zero_()

            # Copy param grad to bucket
            if param.grad is not None:
                grad_in = param.grad.detach().view(-1)[grad_start:grad_end]
                grad_out = bucket.grads_bucket[bucket_start:bucket_end]
                if grad_in.data_ptr() != grad_out.data_ptr():
                    grad_out.add_(grad_in)

        # Free param grad buffer
        param.grad = None

    def grad_buffer_view(self, param):
        """Construct view into grad buffer corresponding to param

        Assumes optimizer is using a contiguous grad buffer.

        """
        assert self.contiguous_grad_buffer

        # Figure out corresponding position in grad buffer
        param_fragments = self.state[param]['fragments']
        start_bucket_id = param_fragments[0].bucket_id
        start_bucket_offset, _ = param_fragments[0].bucket_range
        end_bucket_id = param_fragments[-1].bucket_id
        _, end_bucket_offset = param_fragments[-1].bucket_range
        buffer_start = start_bucket_id * self.bucket_size + start_bucket_offset
        buffer_end = end_bucket_id * self.bucket_size + end_bucket_offset

        # Construct view into grad buffer
        flat_buffer = self._grad_buffer[buffer_start:buffer_end]
        return flat_buffer.detach().view(param.size())

    def _force_bucket_grad_sync(self):
        """Ensure that all gradient buckets are synchronized"""

        # Synchronize all unsynchronized buckets
        self._finish_bucket_grad_sync()
        buckets = [
            bucket
            for bucket_id, bucket in sorted(self._grads_buckets.items())
            if bucket.status != self.GradientStatus.READY
        ]
        if buckets:
            self._start_bucket_grad_sync(buckets)
            self._finish_bucket_grad_sync()

        # Fill any unsynchronized gradients with zeros
        for bucket_id in range(len(self.state['buckets'])):
            bucket = self._grads_buckets[bucket_id]
            if bucket.grads_shard is None:
                bucket.grads_shard = torch.zeros(
                    [self.shard_size],
                    dtype=self.grad_sync_dtype,
                    device=self.device,
                )

        # Reset set of generated gradients
        self._grads_generated = set()

    def _try_start_bucket_grad_sync(
        self,
        params=[],
        ignore_last_bucket=True,
    ):
        """Launches gradient synchronization if enough buckets are ready

        Gradient synchronization is asynchronous. Launches gradient
        synchronization if all gradients have been generated or if
        there are enough buckets ready to fill pipeline.

        Arguments:
            params (iterable): parameters that have had their
                gradients copied to buckets
            ignore_last_bucket (bool): avoid synchronizing last bucket
                until all gradients have been generated. This avoids
                excessive synchronization when initializing buckets in
                the first backward pass.

        """

        # Register params that have generated grads
        for param in params:
            self._grads_generated.add(param)
            for fragment in self.state[param]['fragments']:
                bucket_id = fragment.bucket_id
                bucket_fragments = self.state['buckets'][bucket_id].fragments
                is_filled = True
                for other_fragment in reversed(bucket_fragments):
                    param_group_id = other_fragment.param_group_id
                    param_id = other_fragment.param_id
                    other_param = self.param_groups[param_group_id]['params'][param_id]
                    if other_param not in self._grads_generated:
                        is_filled = False
                        break
                if is_filled:
                    bucket = self._grads_buckets[bucket_id]
                    bucket.status = self.GradientStatus.FULLY_FILLED

        # Launch reductions if enough buckets are ready
        if len(self._grads_generated) == self._num_grads:
            self._force_bucket_grad_sync()
# copy model parameters into master buffer
master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
#print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
if not self._resume_from_checkpoint:
master_param_fragment.copy_(model_param_fragment)
self._contrib_group_properties.append(group_props)
self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, g, p_copy
if self._model_parallel and hasattr(p, 'model_parallel') and not p.model_parallel:
self._non_parallel_grads.append(opti_state_g_fragment)
p, m, v, g, p_copy = list(zip(*self._contrib_tensor_list))
self._contrib_tensor_list = [p, m, v, g, p_copy]
math_type = self._fp32_p.dtype
beta1, beta2, bias_correction, epsilon, decay = list(zip(*self._contrib_group_properties))
self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')
self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')
self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')
self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')
self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')
p_in, p_out = zip(*self._packed_flat_to_model_params)
self._packed_flat_to_model_params = [p_in, p_out]
if self._num_groups > 1:
self._ar_pg = []
for i in range(self._num_process_groups):
# gather global ranks of all members of the current process group
ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]
for j in range(self._group_size):
ar_idx = [j+k*self._group_size for k in range(self._num_groups)]
ar_rank = [ranks[k] for k in ar_idx]
#if self._global_rank in ar_rank:
# print("group for all reduce, ranks:", ar_rank)
for _ in range(self._num_ar_pg):
grp = torch.distributed.new_group(ranks=ar_rank)
if self._global_rank in ar_rank:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
for ar_pg in self._ar_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
self._rs_pg, rs_ranks = [],[]
for i in range(self._num_process_groups):
ranks = [i+k*self._num_process_groups for k in range(self._process_group_size)]
for j in range(self._num_groups):
rs_idx = [j*self._group_size+k for k in range(self._group_size)]
rs_rank = [ranks[k] for k in rs_idx]
#if self._global_rank in rs_rank:
# print("group for reduce scatter, ranks:", rs_rank)
for _ in range(self._num_rs_pg):
grp = torch.distributed.new_group(ranks=rs_rank)
if self._global_rank in rs_rank:
self._rs_pg.append(grp)
if self._compute_L2_grad_norm:
l2_grad_norm_pg = torch.distributed.new_group(ranks=rs_rank)
if self._global_rank in rs_rank:
self._l2_grad_norm_pg = l2_grad_norm_pg
torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
for rs_pg in self._rs_pg:
torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
        else:
            filled_buckets = []
            for bucket_id, bucket in sorted(self._grads_buckets.items()):
                if ignore_last_bucket and bucket_id == len(self.state['buckets'])-1:
                    continue
                if bucket.status == self.GradientStatus.FULLY_FILLED:
                    filled_buckets.append(bucket)
            pipeline_size = _round_to_multiple(
                len(filled_buckets),
                self.pipeline_size,
            )
            if pipeline_size > 0:
                self._start_bucket_grad_sync(filled_buckets[:pipeline_size])

    def _start_bucket_grad_sync(self, buckets):
        """Synchronize gradient buckets

        Gradient synchronization is asynchronous. Involves
        reduce-scatter over distributed process group and allreduce
        over redundant process group.

        """

        # Call recursively if more buckets than streams
        while len(buckets) > self.pipeline_size:
            self._start_bucket_grad_sync(buckets[:self.pipeline_size])
            buckets = buckets[self.pipeline_size:]
            self._finish_bucket_grad_sync()

        # Reduction operation
        if self.average_grad_sync:
            reduce_op = torch.distributed.ReduceOp.AVG
def _get_flush_block(self):
flush_block = []
if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
num_grads = len(self._grads_generated)
contiguous_idx = num_grads
while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
contiguous_idx -= 1
if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
self._current_block -= 1
start = self._current_block * self._block_size
end = (self._current_block+1) * self._block_size
flush_block = [start, end]
return flush_block
def _pipeline_block_reductions(self, block_id):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(rs_stream):
works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
works[chunk_id].wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
# Optionally compute L2 grad norm
if self._compute_L2_grad_norm and block_id == 0:
with torch.cuda.stream(self._l2_grad_norm_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed
l2_grad_norm_sq = torch.empty([1], device='cuda')
l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
# for model_parallel_rank=0, keep all gradients
# for the rest, subtract non_parallel gradients
if self._model_parallel and self._process_group_id: # non zero model_parallel_rank
non_parallel_grad_norm_sq = torch.zeros([1], device='cuda')
if len(self._non_parallel_grads): # non parallel grads exit
non_parallel_grad_norm_sq = multi_tensor_applier(self.multi_tensor_l2norm,
self._overflow_buf,
[self._non_parallel_grads], False)[0]**2
torch.distributed.all_reduce(non_parallel_grad_norm_sq, group=self._l2_grad_norm_pg)
l2_grad_norm_sq = l2_grad_norm_sq - non_parallel_grad_norm_sq
self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
def __launch_step_kernel(self):
# If self._clip_grad_norm is False, we assume gradient clipping already
# happened outside the optimizer and self._global_scale has already
# been set to the combined scale, i.e. it's no longer the current loss
# scale used by the loss scaler.
# For model parallelism cases in which we need to get global gradient
# norm via all-reduce outside the optimizer to do the clipping.
combined_scale = self._global_scale
if self._clip_grad_norm and self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
self._step += 1
multi_tensor_applier(distributed_adam_cuda.multi_tensor_fused_adam,
self._overflow_buf,
self._contrib_tensor_list, # p, m, v, g, p_copy
self._contrib_beta1,
self._contrib_beta2,
self._contrib_bias_correction,
self._contrib_epsilon,
self._contrib_weight_decay,
self._param_group['lr'],
combined_scale,
self._step,
self.eps_mode)
def _pipeline_step(self):
# Call step kernel once per step
# Call all-gather once per step
with torch.cuda.stream(self._completion_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel()
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale):
if self._flat_mt and len(self._grads) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads)),
scale)
self._grads = []
        else:
            reduce_op = torch.distributed.ReduceOp.SUM

        # Reduce gradients
        main_stream = torch.cuda.current_stream()
        for stream in self._pipeline_streams:
            stream.wait_stream(main_stream)
        for i, bucket in enumerate(buckets):
            bucket.status = self.GradientStatus.SYNCING
            stream = self._pipeline_streams[i % self.pipeline_size]
            with torch.cuda.stream(stream):

                # Reduce-scatter over distributed process group
bucket.sync_wait()
if self.distributed_size == 1:
bucket.sync_grads_shard = bucket.grads_bucket
else:
with torch.cuda.stream(main_stream):
bucket.sync_grads_shard = torch.zeros(
[self.shard_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
grads_bucket_shards = [
bucket.grads_bucket[i*self.shard_size:(i+1)*self.shard_size]
for i in range(self.distributed_size)
]
if self._reduce_scatter_no_copy:
no_copy_kwarg = { 'no_copy': True }
else:
no_copy_kwarg = {}
bucket.sync_request = (
torch.distributed.reduce_scatter(
bucket.sync_grads_shard,
grads_bucket_shards,
op=reduce_op,
group=self.distributed_process_group,
async_op=True,
**no_copy_kwarg,
)
)
# All-reduce over redundant process group
# Note: Assuming reduce-scatters are finished in the
# order they are submitted, all-reduces should be
# submitted in a consistent order. There could be race
# conditions if wait doesn't finish in order.
if self.redundant_size > 1:
bucket.sync_wait()
bucket.sync_request = (
torch.distributed.all_reduce(
bucket.sync_grads_shard,
op=reduce_op,
group=self.redundant_process_group,
async_op=True,
)
)
def _finish_bucket_grad_sync(self):
"""Wait for any gradient synchronizations that are in progress"""
for bucket_id, bucket in sorted(self._grads_buckets.items()):
if bucket.status == self.GradientStatus.SYNCING:
# Finish asynchronous communication
bucket.sync_wait()
# Accumulate gradient in local shard
if bucket.grads_shard is None:
bucket.grads_shard = bucket.sync_grads_shard
else:
bucket.grads_shard.add_(bucket.sync_grads_shard)
bucket.grads_bucket = None
bucket.sync_grads_shard = None
# Reset status
bucket.status = self.GradientStatus.READY
# Cached gradient norm has been invalidated
self._grad_norm = None
@contextlib.contextmanager
def no_sync(self, greedy_grad_copy=False):
"""Disable overlapped gradient synchronization
Context manager that is similar to
torch.nn.parallel.DistributedDataParallel.no_sync. The
gradients can be synchronized by calling grad_sync or step. If
overlapped gradient synchronization is enabled, gradients can
also be synchronized by leaving the context and performing a
backward pass.
Arguments:
greedy_grad_copy (bool, optional): copy parameter
gradients to buckets as soon as they are generated
(default: False)
""" """
self._global_scale = global_scale old_greedy_grad_copy = self.greedy_grad_copy
old_overlap_grad_sync = self.overlap_grad_sync
self.greedy_grad_copy = greedy_grad_copy
self.overlap_grad_sync = False
try:
yield
finally:
self.greedy_grad_copy = old_greedy_grad_copy
self.overlap_grad_sync = old_overlap_grad_sync
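A sketch of how no_sync might be used for gradient accumulation, in the spirit of torch.nn.parallel.DistributedDataParallel.no_sync (model, loss_fn, and microbatches are placeholders):

    # Accumulate gradients over several micro-batches; only the last backward
    # pass copies them into buckets and launches the overlapped reductions.
    for i, (inputs, targets) in enumerate(microbatches):
        if i < len(microbatches) - 1:
            with optimizer.no_sync():
                loss_fn(model(inputs), targets).backward()
        else:
            loss_fn(model(inputs), targets).backward()
    optimizer.step()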
def grad_sync(self):
"""Ensure that all gradients are synchronized"""
for bucket in self.state['buckets']:
for fragment in bucket.fragments:
param_group_id = fragment.param_group_id
param_id = fragment.param_id
param = self.param_groups[param_group_id]['params'][param_id]
if param.grad is not None:
self._grad_copy(param)
self._try_start_bucket_grad_sync(
params=[param],
ignore_last_bucket=False,
)
self._force_bucket_grad_sync()
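Since the optimizer sets _step_supports_amp_scaling and step accepts a grad_scaler argument (see step below), a mixed-precision loop could look roughly like this (sketch, under the assumption that the scaler-state handling in step matches torch.cuda.amp.GradScaler; model and loss_fn are placeholders):

    scaler = torch.cuda.amp.GradScaler()
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()
    optimizer.step(grad_scaler=scaler)   # unscaling is folded into the fused Adam kernel
    scaler.update()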
def _local_grad_norm(self, parameters=[], norm_type=2.0):
"""Local contribution to parameter gradient norm
Returns square of 2-norm. Other norms are not yet supported.
        If no parameters are provided, the norm is computed for all
        parameters in optimizer. Provided parameters are assumed to be
        in optimizer.

        """
        norm_type = float(norm_type)
        assert norm_type == 2.0

        # Make sure that gradients have been reduced
        self.grad_sync()

        if not parameters or len(parameters) == self._num_grads:
            # Compute norm of all local gradients
            dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
            grad_norm_sq = multi_tensor_applier(
                amp_C.multi_tensor_l2norm,
                dummy_overflow_buf,
                [[bucket.grads_shard for bucket in self._grads_buckets.values()]],
                False,
            )[0] ** 2
else:
# Compute norm of selected local gradients
grads = []
for param in parameters:
for fragment in self.state[param]['fragments']:
if fragment.in_local_shard:
bucket = self._grads_buckets[fragment.bucket_id]
shard_start, shard_end = fragment.shard_range
grads.append(bucket.grads_shard[shard_start:shard_end])
if grads:
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
grad_norm_sq = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads],
False,
)[0] ** 2
else:
grad_norm_sq = torch.zeros([1], dtype=torch.float32, device=self.device)
return grad_norm_sq.detach().view([])
def grad_norm(self, parameters=[], norm_type=2.0, force=False):
"""Gradient norm of parameters in optimizer
The norm is computed over all gradients together, as if they
were concatenated into a single vector. All provided
parameters must be managed by optimizer.
The computed value is cached to avoid redundant communication.
Arguments:
parameters (iterable, optional): an iterable of parameters
in optimizer (default: all parameters in optimizer).
norm_type (float or int, optional): type of the used
p-norm (default: 2). Only 2-norm is currently
supported.
force (bool, optional): ignore cached value and force norm
computation (default: False).
""" """
return self._has_overflow if force or self._grad_norm is None:
norm_type = float(norm_type)
assert norm_type == 2.0
grad_norm_sq = self._local_grad_norm(
parameters=parameters,
norm_type=norm_type,
)
torch.distributed.all_reduce(
grad_norm_sq,
op=torch.distributed.ReduceOp.SUM,
group=self.distributed_process_group,
)
self._grad_norm = grad_norm_sq.sqrt()
return self._grad_norm.detach()
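For example, clipping to a max norm of 1.0 before the step might look like the following (sketch; because the scaling is deferred, clip_grad_norm should be followed immediately by step):

    loss.backward()
    total_norm = optimizer.clip_grad_norm(1.0)   # caches the norm, defers the scaling
    optimizer.step()                             # the clip factor is applied in the fused kernel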
def clip_grad_norm(self, max_norm, parameters=[], norm_type=2.0):
"""Clips gradient norm of parameters in optimizer
The norm is computed over all gradients together, as if they
were concatenated into a single vector. The scaling is
deferred until the optimizer step, which should be called
immediately after this function.
The computed grad norm is cached to avoid redundant
communication.
Arguments:
max_norm (float or int): max norm of the gradients
parameters (iterable, optional): an iterable of parameters
in optimizer (default: all parameters in optimizer).
norm_type (float or int, optional): type of the used
p-norm (default: 2)
        """
        assert max_norm > 0
        total_norm = self.grad_norm(parameters=parameters, norm_type=norm_type)
        inv_clip_coef = (total_norm + 1e-6) / max_norm
        self._inv_grad_scale = torch.clamp(inv_clip_coef, min=1.0).view(1)
        return total_norm

    def step(self, closure=None, *, grad_scaler=None):
        """Apply Adam optimizer step

        Arguments:
            closure (callable, optional): closure to recompute loss
                (default: None)
            grad_scaler (torch.cuda.amp.GradScaler, optional):
                gradient scaler (default: None)

        """

        # Apply closure
        loss = None
        if closure is not None:
            loss = closure()

        # Make sure that gradients have been reduced
self.grad_sync()
# Apply gradient scaler if provided
# Note: We compute gradient norm to check for non-finite
# values. This is more conservative and compute intensive than
# directly checking, but it avoids extra communication if we
# have already computed gradient norm e.g. for gradient
# clipping.
if grad_scaler is not None:
grad_norm = self.grad_norm()
found_inf = torch.logical_not(torch.isfinite(grad_norm))
scaler_state = grad_scaler._per_optimizer_states[id(self)]
scaler_state['found_inf_per_device'] = {found_inf.device: found_inf.float()}
if found_inf.item():
return
else:
assert grad_scaler._scale is not None
self._inv_grad_scale *= grad_scaler._scale
inv_grad_scale = self._inv_grad_scale.item()
# Construct workspace buffers
params_bucket_buffers = [
torch.empty(
[self.bucket_size],
dtype=self.param_sync_dtype,
device=self.device,
)
for _ in range(self.pipeline_size)
]
if self.grad_sync_dtype == self.param_sync_dtype:
shard_start = self.distributed_rank * self.shard_size
shard_end = shard_start + self.shard_size
params_copy_buffers = [
params_bucket[shard_start:shard_end]
for params_bucket in params_bucket_buffers
]
else:
params_copy_buffers = [
torch.empty(
[self.shard_size],
dtype=self.grad_sync_dtype,
device=self.device,
)
for _ in range(self.pipeline_size)
]
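# Workspace strategy: each pipeline stream owns one bucket-sized buffer in the
# param-sync dtype. When gradients and parameters are synchronized in the same
# dtype, the local shard slice of that buffer doubles as the optimizer output;
# otherwise a separate shard-sized buffer in the grad-sync dtype is used and
# cast into the bucket buffer after the Adam step.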
# Apply optimizer step to each bucket and synchronize params
self.state['step'] += 1
main_stream = torch.cuda.current_stream()
for stream in self._pipeline_streams:
stream.wait_stream(main_stream)
for bucket_id in range(len(self.state['buckets'])):
stream_id = bucket_id % self.pipeline_size
# Bucket buffers
fragments = self.state['buckets'][bucket_id].fragments
shard_start = self.distributed_rank * self.shard_size
shard_end = shard_start + self.shard_size
params_bucket = params_bucket_buffers[stream_id]
params_bucket_shard = params_bucket[shard_start:shard_end]
params_shard = self.state['buckets'][bucket_id].params_shard
params_copy = params_copy_buffers[stream_id]
exp_avg = self.state['buckets'][bucket_id].exp_avg_shard
exp_avg_sq = self.state['buckets'][bucket_id].exp_avg_sq_shard
grads = self._grads_buckets[bucket_id].grads_shard
# Perform compute on parallel stream
stream = self._pipeline_streams[stream_id]
with torch.cuda.stream(stream):
# Find param fragments in local shard
buffers = collections.defaultdict(list) # p, m, v, g, p_copy
for fragment in fragments:
if fragment.in_local_shard:
param_group_id = fragment.param_group_id
shard_start, shard_end = fragment.shard_range
buffers[param_group_id].append([
params_shard[shard_start:shard_end],
exp_avg[shard_start:shard_end],
exp_avg_sq[shard_start:shard_end],
grads[shard_start:shard_end],
params_copy[shard_start:shard_end],
])
                # Fuse param fragments if possible
                if len(buffers) == 1:
                    group_id = list(buffers.keys())[0]
                    buffers[group_id] = [(
                        params_shard,
                        exp_avg,
                        exp_avg_sq,
                        grads,
                        params_copy,
                    )]

                # Apply optimizer step to each param group
                for group_id, group_buffers in buffers.items():

                    # Get param group configs
                    group = self.param_groups[group_id]
beta1, beta2 = group['betas']
bias_correction = 1 if group['bias_correction'] else 0
eps = group['eps']
weight_decay = group['weight_decay']
# Copy param group configs to GPU
num_fragments = len(group_buffers)
beta1 = torch.full([num_fragments], beta1, dtype=self.dtype, device='cuda')
beta2 = torch.full([num_fragments], beta2, dtype=self.dtype, device='cuda')
bias_correction = torch.full([num_fragments], bias_correction, dtype=torch.int32, device='cuda')
eps = torch.full([num_fragments], eps, dtype=self.dtype, device='cuda')
weight_decay = torch.full([num_fragments], weight_decay, dtype=self.dtype, device='cuda')
# Apply Adam step
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
multi_tensor_applier(
distributed_adam_cuda.multi_tensor_fused_adam,
dummy_overflow_buf,
list(zip(*group_buffers)),
beta1,
beta2,
bias_correction,
eps,
weight_decay,
group['lr'],
inv_grad_scale,
self.state['step'],
1, # Set to 0 to apply eps inside sqrt
)
# Cast parameter dtype if needed
if params_copy.data_ptr() != params_bucket_shard.data_ptr():
params_bucket_shard.copy_(params_copy)
# Allgather updated parameters
if self.distributed_size > 1:
all_params_bucket_shards = [
params_bucket[i*self.shard_size:(i+1)*self.shard_size]
for i in range(self.distributed_size)
]
if self._all_gather_no_copy:
no_copy_kwarg = { 'no_copy': True }
else:
no_copy_kwarg = {}
torch.distributed.all_gather(
all_params_bucket_shards,
params_bucket_shard,
group=self.distributed_process_group,
**no_copy_kwarg,
)
# Copy values to param buffers
buffers = collections.defaultdict(list) # param_in, param_out
for fragment in fragments:
param_group_id = fragment.param_group_id
param_id = fragment.param_id
param = self.param_groups[param_group_id]['params'][param_id]
bucket_start, bucket_end = fragment.bucket_range
param_start, param_end = fragment.param_range
param_in = params_bucket[bucket_start:bucket_end]
param_out = param.detach().view(-1)[param_start:param_end]
if param_in.dtype == param_out.dtype:
# Just copy bytes if buffers have same type
param_in = param_in.view(torch.uint8)
param_out = param_out.view(torch.uint8)
buffers[(param.is_cuda, param.dtype)].append(
(param_in, param_out)
)
for (is_cuda, dtype), dtype_buffers in buffers.items():
fused_kernel_dtypes = (
self.param_sync_dtype,
torch.float32,
torch.float16,
torch.uint8,
)
if is_cuda and dtype in fused_kernel_dtypes:
dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda')
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
dummy_overflow_buf,
list(zip(*dtype_buffers)),
)
else:
for param_in, param_out in dtype_buffers:
param_out.copy_(param_in)
# Synchronize pipeline streams
for stream in self._pipeline_streams:
main_stream.wait_stream(stream)
        return loss

    def state_dict(self, gather_on_root=True):
        """Get dictionary containing optimizer state

        Default behavior is to perform communication so that the
        entire optimizer state is returned on the root rank in the
        process group. In this case, all ranks in the process group
        must enter this function and no value is returned on non-root
        ranks.

        Arguments:
            gather_on_root (bool, optional): Gather state from all
                ranks on the root rank (default: True)

        """
        state_dict = super().state_dict()
        if not gather_on_root:
            return state_dict

        # Export local state to byte string
        state_bytes = io.BytesIO()
        torch.save(state_dict, state_bytes)
state_bytes.seek(0)
state_bytes_view = state_bytes.getbuffer()
# Get data sizes on all ranks
local_state_size = len(state_bytes_view)
state_sizes = [None] * self.distributed_size
torch.distributed.all_gather_object(
state_sizes,
local_state_size,
group=self.process_group,
)
max_state_size = max(state_sizes)
# Construct workspace buffers
chunk_size = self.shard_size * torch.finfo(self.grad_sync_dtype).bits // 8
if self.distributed_rank == 0:
gathered_state_bytes = [state_bytes.getvalue()]
gathered_state_bytes.extend(bytearray(size) for size in state_sizes[1:])
gathered_chunks_buffers = [
torch.empty(
[chunk_size * self.distributed_size],
dtype=torch.uint8,
device=self.device,
)
for _ in range(self.pipeline_size)
]
else:
chunk_buffers = [
torch.empty(
[chunk_size],
dtype=torch.uint8,
device=self.device,
)
for _ in range(self.pipeline_size)
]
# Split data into chunks and gather on root rank
# Note: Assuming we are using the NCCL backend, communication
# must happen on the GPU. We split the data into fixed-size
# chunks so that the GPU memory usage is limited to
# (chunk_size * distributed_size) bytes.
# TODO: Avoid chunking with direct communication between CPUs
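# Illustrative sizing (hypothetical numbers): with a 4 Mi-element fp16 shard,
# chunk_size is 8 MiB per rank, so on 16 ranks each gathered-chunk buffer on
# the root occupies 8 MiB * 16 = 128 MiB, and there is one such buffer per
# pipeline stream.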
main_stream = torch.cuda.current_stream()
for stream in self._pipeline_streams:
stream.wait_stream(main_stream)
for stream_id, offset in enumerate(range(0, max_state_size, chunk_size)):
stream_id %= self.pipeline_size
# Buffers for chunk
if self.distributed_rank == 0:
gathered_chunks = [
gathered_chunks_buffers[stream_id][i*chunk_size:(i+1)*chunk_size]
for i in range(self.distributed_size)
]
else:
chunk = chunk_buffers[stream_id]
# Perform communication on parallel stream
stream = self._pipeline_streams[stream_id]
with torch.cuda.stream(stream):
# Copy to GPU
if self.distributed_rank != 0 and offset < local_state_size:
local_chunk_size = min(chunk_size, local_state_size-offset)
chunk[:local_chunk_size].copy_(
torch.frombuffer(
state_bytes_view,
dtype=torch.uint8,
count=local_chunk_size,
offset=offset,
),
non_blocking=True,
)
# Gather on root
if self.distributed_rank == 0:
if self._gather_no_copy:
no_copy_kwarg = { 'no_copy': True }
else:
no_copy_kwarg = {}
torch.distributed.gather(
gathered_chunks[0],
gathered_chunks,
dst=self._process_group_ranks[0],
group=self.process_group,
**no_copy_kwarg,
)
else:
torch.distributed.gather(
chunk,
dst=self._process_group_ranks[0],
group=self.process_group,
)
# Copy back to CPU
if self.distributed_rank == 0:
for rank in range(1, self.distributed_size):
if offset < state_sizes[rank]:
rank_chunk_size = min(chunk_size, state_sizes[rank]-offset)
torch.frombuffer(
gathered_state_bytes[rank],
dtype=torch.uint8,
count=rank_chunk_size,
offset=offset,
).copy_(
gathered_chunks[rank][:rank_chunk_size],
non_blocking=True,
)
# Synchronize GPU
for stream in self._pipeline_streams:
main_stream.wait_stream(stream)
main_stream.synchronize()
# Return gathered state data on root rank
if self.distributed_rank == 0:
return {'gathered_states': gathered_state_bytes}
else:
return None
    def load_state_dict(self, state_dict):
        """Load optimizer state"""

        # State dict contains state for all ranks
        if 'gathered_states' in state_dict:

            # Deallocate distributed optimizer state to reduce GPU
            # memory usage
            if 'buckets' in self.state:
                del self.state['buckets']

            # Get state for current rank and parse byte string
            state_bytes = state_dict['gathered_states'][self.distributed_rank]
            state_bytes = io.BytesIO(state_bytes)
            state_dict = torch.load(state_bytes)

        return super().load_state_dict(state_dict)
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
class DistributedFusedAdamV2(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
overlap_reductions (boolean, optional): whether to overlap reductions
with bprop (default: True)
num_prestats (integer, optional): number of fp64 stats that will be
reduced during first fp16 gradient reduction block.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params,
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
compute_L2_grad_norm=False, distributed_weight_update=0,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
do_not_flatten_model=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._amp_scale_adjustment = amp_scale_adjustment
if use_mt:
raise RuntimeError('DistributedFusedAdam does not support use_mt.')
if amsgrad:
raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(DistributedFusedAdamV2, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
self._overflow_buf = torch.cuda.IntTensor([0])
self._has_overflow = False
assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
# Way to revert a step
# 3 -> undo kernel + double buffer (debug, print norm of difference)
# 2 -> double buffer fp32 parameters
# 1 -> undo kernel
self._revert_method = revert_method
if self._revert_method > 1:
print("revert_method -> double buffer fp32 parameters, will consume more memory")
self._last_step = False
self._overlap_reductions = overlap_reductions
self._global_scale = None
self._num_blocks = dwu_num_blocks
self._num_chunks = dwu_num_chunks
self._predivide = predivide
self._e5m2_allgather = e5m2_allgather
self._do_not_flatten_model = do_not_flatten_model
self._full_pipeline = full_pipeline
self._compute_L2_grad_norm = compute_L2_grad_norm
self._L2_grad_norm = None
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._world_size = torch.distributed.get_world_size()
self._num_groups = self._world_size // self._group_size
self._rank_in_group = torch.distributed.get_rank() % self._group_size
p_offset = 0
p_i = 0
self._param_state = None
self._model_params = []
self._grads_info = []
self._grad_accs = []
for group in self.param_groups:
self._param_group = group
prev = None
for p in group['params']:
torch.distributed.broadcast(p,0)
if not p.requires_grad:
continue
self._model_params.append(p)
state = self.state[p]
if len(state) == 0:
state['step'] = 0
if self._param_state is None:
self._param_state = state
p_grads_size = p.numel()
def wrapper(param, param_i, param_grads_size, param_offset):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
wrapper(p, p_i, p_grads_size, p_offset)
p_offset += p_grads_size
# Only enforce 128-byte alignment (64 fp16 elements) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
p_i += 1
self._grads_generated = [False]*len(self._grads_info)
self._flat_mt = flat_mt
self._grads = []
if self._overlap_reductions:
self._current_block = self._num_blocks
self._net_total_param_size = p_offset
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._block_size = self._total_param_size // self._num_blocks
self._shard_size = self._block_size // self._group_size
self._chunk_size = self._shard_size // self._num_chunks
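# Sizing example (illustrative numbers only): with dwu_num_blocks=4,
# dwu_num_chunks=4 and group_size=8, dwu_min_page_size = 256*4*4*8 = 32768, so a
# net parameter count of 1,000,000 is padded to 1,015,808, giving
# block_size=253,952, shard_size=31,744 and chunk_size=7,936.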
print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d, self._chunk_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size,self._chunk_size))
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
print(self._low_param_i)
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._mega_shard_size = self._num_blocks * self._num_chunks * self._chunk_size
self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
# FIXME: Rethink fp16 label since it's either uint8 or fp16
self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
self._individual_flat_grads = []
for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))
def _flat_split(p):
def __blockify(p):
return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __shardify(p):
return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._group_size)]
list_of_blocks = __blockify(self._flat_grads)
list_of_list_of_shards = [__shardify(block) for block in list_of_blocks]
list_of_list_of_list_of_chunks = [[__chunkify(shard) for shard in shards] for shards in list_of_list_of_shards]
return list_of_blocks, list_of_list_of_shards, list_of_list_of_list_of_chunks
self._flat_grads_blocks, self._flat_grads_shards, self._flat_grads_chunks = _flat_split(self._flat_grads)
def _full_packed_split(p):
def __shardify(p):
return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
def __blockify(p):
return [p[block_id*self._num_chunks*self._chunk_size:(block_id+1)*self._num_chunks*self._chunk_size] for block_id in range(self._num_blocks)]
def __chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
list_of_mega_shards = __shardify(p)
list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
def _packed_split(p):
def __packed_blockify(p):
packed_block_size = self._num_chunks*self._chunk_size
return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
def __packed_chunkify(p):
return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
list_of_blocks = __packed_blockify(p)
list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
return list_of_blocks, list_of_list_of_chunks
self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
# current arrangement
#
# self._flat_grads
# self._flat_grads_blocks [x self._num_blocks, self._block_size]
# self._flat_grads_chunks [x self._num_chunks, self._chunk_size]
# self._flat_grads_shards [x self._group_size, self._shard_size]
#
# self._new_params
# self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._shard_size]
# self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._shard_size]
# self._new_params_mega_chunks [x self._num_chunks, self._shard_size]
#
# self._fp32_p
# self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._shard_size]
# self._fp32_p_chunks [x self._num_chunks, self._shard_size]
# each chunk contains one shard
# same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g
#
# Usage:
#
# for chunk_id in range(self._num_chunks):
# works[chunk_id] = torch.distributed.reduce_scatter(self._flat_grads_chunks[block_id][chunk_id], self._fp16_g_chunks[block_id][chunk_id], ...)
#
# ----------------------------------------------------------------------------------------
#
# new arrangement
#
# NB! New equations for self._shard_size and self._chunk_size
#
# self._flat_grads
# self._flat_grads_blocks [x self._num_blocks, self._block_size]
# self._flat_grads_shards [x self._group_size, self._shard_size]
# self._flat_grads_chunks [x self._num_chunks, self._chunk_size]
#
# self._new_params
# self._new_params_mega_shards [x self._group_size, self._num_blocks*self._num_chunks*self._chunk_size]
# self._new_params_mega_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]
# self._new_params_mega_chunks [x self._num_chunks, self._chunk_size]
#
# self._fp32_p
# self._fp32_p_blocks [x self._num_blocks, self._num_chunks*self._chunk_size]
# self._fp32_p_chunks [x self._num_chunks, self._chunk_size]
# same for self._fp32_m, self._fp32_v, self._fp16_p and self._fp16_g
#
# Usage:
#
# work = torch.distributed.reduce_scatter(self._flat_grads_blocks[block_id], self._fp16_g[block_id], ...)
# for chunk_id in range(self._num_chunks):
# work.wait()
# works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], ...)
# or
# work.wait()
# works[0] = torch.distributed.all_reduce(self._fp16_g_blocks[block_id], ...)
#
# This paragraph does two things:
# 1) Copy model parameters into master buffer
# 2) Create tensor lists for unpacking new parameter tensor after all-gather
self._packed_flat_to_model_params = []
for shard_id in range(self._group_size):
for block_id in range(self._num_blocks):
flat_shard_start = (block_id * self._group_size + shard_id) * self._shard_size
flat_shard_end = flat_shard_start + self._shard_size
for p, grads_info in zip(self._model_params, self._grads_info):
flat_grad_start = grads_info["param_offset"]
flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
clipped_start = max(flat_grad_start, flat_shard_start)
clipped_end = min(flat_grad_end, flat_shard_end)
if clipped_start < clipped_end:
grad_offset = clipped_start - flat_grad_start
grad_length = clipped_end - clipped_start
shard_offset = clipped_start - flat_shard_start
model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
new_param_packed_fragment = self._new_params_mega_blocks[shard_id][block_id][shard_offset:shard_offset+grad_length]
self._packed_flat_to_model_params.append( (new_param_packed_fragment, model_param_fragment) )
if shard_id == self._rank_in_group:
# copy model parameters into master buffer
master_param_fragment = self._fp32_p_blocks[block_id][shard_offset:shard_offset+grad_length]
print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
master_param_fragment.copy_(model_param_fragment)
p_in, p_out = zip(*self._packed_flat_to_model_params)
self._packed_flat_to_model_params = [p_in, p_out]
self._distributed_weight_update = distributed_weight_update # Is this still needed?
self._num_rs_pg = dwu_num_rs_pg
self._num_ar_pg = dwu_num_ar_pg
self._num_ag_pg = dwu_num_ag_pg
if self._num_groups > 1:
self._ar_pg = []
for dev_i in range(self._group_size):
ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
for i in range(self._num_ar_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ar_pg.append(grp)
self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
for ar_pg in self._ar_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
rs_ranks = []
for group_i in range(self._num_groups):
rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
self._rs_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_rs_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
for rs_pg in self._rs_pg:
torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
if self._num_ag_pg == 0:
self._ag_pg = self._rs_pg
self._ag_st = self._rs_st
self._num_ag_pg = self._num_rs_pg
else:
self._ag_pg = []
for group_i in range(self._num_groups):
ranks = rs_ranks[group_i]
for i in range(self._num_ag_pg):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg.append(grp)
self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
for ag_pg in self._ag_pg:
torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
self._completion_st = torch.cuda.Stream()
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
def set_last_step(self, last_step):
self._last_step = last_step
def _get_flush_block(self):
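# Returns the [start, end) offsets (as a two-element list) of the highest
# not-yet-flushed gradient block whose parameters have all produced gradients;
# returns an empty list when no further block is ready.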
flush_block = []
if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
num_grads = len(self._grads_generated)
contiguous_idx = num_grads
while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
contiguous_idx -= 1
if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
self._current_block -= 1
start = self._current_block * self._block_size
end = (self._current_block+1) * self._block_size
flush_block = [start, end]
return flush_block
def _pipeline_block_reductions(self, block_id):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
# Reduction within each node
# Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
# The output format is the same as the fp32 master parameters
works = [None]*self._num_chunks
rs_stream = self._rs_st[block_id%self._num_rs_pg]
rs_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(rs_stream):
rs_work = torch.distributed.reduce_scatter(self._fp16_g_blocks[block_id],self._flat_grads_shards[block_id],group=self._rs_pg[block_id%self._num_rs_pg],async_op=True,no_copy=True)
for chunk_id in range(self._num_chunks):
works[chunk_id] = rs_work
# Reduction across nodes for each rank
if self._num_groups > 1:
for chunk_id in range(self._num_chunks):
glob_chunk_id = block_id * self._num_chunks + chunk_id
ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
with torch.cuda.stream(ar_stream):
rs_work.wait()
works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
self._reductions_works[block_id] = works
# Optionally compute L2 grad norm
if self._compute_L2_grad_norm and block_id == 0:
with torch.cuda.stream(self._l2_grad_norm_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
# Since the packed format is contiguous after reductions, only one norm is needed
l2_grad_norm_sq = torch.empty([1], device='cuda')
l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
def __launch_step_kernel(self, p, p_copy, m, v, g):
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
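# combined_scale folds loss-scale removal and gradient clipping into one
# divisor: dividing the fp16 gradients by it (which the reversible_adam kernel
# is expected to do) removes global_scale and, when the unscaled norm
# L2_grad_norm/global_scale exceeds max_grad_norm, additionally shrinks the
# gradients so their norm comes out at approximately max_grad_norm.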
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.reversible_adam(
p, p_copy, m, v, g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def _pipeline_block_step(self, block_id):
# Call step kernel once per block
ag_stream = self._ag_st[block_id%self._num_ag_pg]
with torch.cuda.stream(ag_stream):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel(
self._fp32_p_blocks[block_id],
self._fp16_p_blocks[block_id],
self._fp32_m_blocks[block_id],
self._fp32_v_blocks[block_id],
self._fp16_g_blocks[block_id])
# Call all-gather once per step.
# FIXME: Determine which is faster, one all-gather per block or a single all-gather at end
if block_id == 0:
for other_ag_stream in self._ag_st:
self._completion_st.wait_stream(other_ag_stream)
with torch.cuda.stream(self._completion_st):
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _pipeline_step(self):
# Call step kernel once per step
# Call all-gather once per step
with torch.cuda.stream(self._completion_st):
for block_id in range(self._num_blocks):
for chunk_id in range(self._num_chunks):
self._reductions_works[block_id][chunk_id].wait()
self.__launch_step_kernel(
self._fp32_p,
self._fp16_p,
self._fp32_m,
self._fp32_v,
self._fp16_g)
torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
def _flatten_grad_mt(self, scale):
if self._flat_mt and len(self._grads) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads)),
scale)
self._grads = []
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
# handle overlapped reductions
if self._flat_mt:
self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
else:
torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
self._grads_generated[param_i]=True
if not self._last_step:
if self._overlap_reductions:
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._pipeline_block_reductions(block_id)
if self._full_pipeline:
self._pipeline_block_step(block_id)
flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
"""Set global scale.
"""
self._global_scale = global_scale
@property
def global_scale(self):
return self._global_scale
@property
def has_overflow(self):
"""Check if overflows were detected by any call to step(...) method.
Clears the overflow flag.
"""
has_overflow = self._has_overflow
self._has_overflow = False
return has_overflow
@property
def peek_overflow(self):
"""Check if overflows were detected by any call to step(...) method.
Does not clear overflow flag.
"""
return self._has_overflow
def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
"""Strided check for overflow.
You can get status by calling has_overflow.
"""
if start >= 0 and start < end:
out_p = output_params[start:end]
else:
out_p = output_params
fused_adam_cuda.strided_check_finite(self._overflow_buf,
out_p,
stride,
1 if clear else 0)
self._has_overflow = False if self._overflow_buf.item() == 0 else True
return self._has_overflow
@property
def L2_grad_norm(self):
if self._compute_L2_grad_norm:
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
else:
return None
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
if self._last_step:
# zero out gradients that have not been completed yet
for param_i, grad_generated in enumerate(self._grads_generated):
if not grad_generated:
grad_info = self._grads_info[param_i]
param_offset = grad_info["param_offset"]
param_size = grad_info["param_grads_size"]
self._flat_grads[param_offset:param_offset+param_size].zero_()
self._grads_generated[param_i] = True
if self._last_step or not self._overlap_reductions:
# nothing done so far, run full pipeline after reductions
for block_id in range(self._num_blocks-1,-1,-1):
self._pipeline_block_reductions(block_id)
if self._compute_L2_grad_norm:
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
self._current_block = self._num_blocks
self._grads_generated = [False]*len(self._grads_info)
def revert_step(self):
"""Revert effect of previously calling partial_step.
"""
# Call undo kernel once per step
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.maybe_adam_undo(
torch.empty([0]),
self._fp32_p,
self._fp32_m,
self._fp32_v,
self._fp16_g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def step(self, closure=None, skip_overflow_check=False):
loss = None
if closure is not None:
loss = closure()
if self._last_step or not self._overlap_reductions or not self._full_pipeline:
self._pipeline_step()
with torch.cuda.stream(self._completion_st):
# Check for overflow
# Store state for loss scaler calculation
has_overflow = False if skip_overflow_check else self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
if has_overflow:
self.revert_step()
else:
# Copy self._new_params to model params
for p in self._model_params: self.state[p]['step'] += 1
multi_tensor_applier(
fused_adam_cuda.maybe_cast_mt,
self._overflow_buf,
self._packed_flat_to_model_params)
torch.cuda.current_stream().wait_stream(self._completion_st)
self._reductions_works = [None]*self._num_blocks
self._allgather_works = [None]*self._num_blocks
return loss
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier
class DistributedFusedAdamV3(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the 'update parameters' step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
overlap_reductions (boolean, optional): whether to overlap reductions
with bprop (default: True)
num_prestats (integer, optional): number of fp64 stats that will be
reduced during first fp16 gradient reduction block.
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params,
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
amp_scale_adjustment=1.0, overlap_reductions=True, full_pipeline=True,
compute_L2_grad_norm=False, distributed_weight_update=0,
dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
dwu_num_ag_pg=0, revert_method=1, flat_mt=False,
dwu_num_chunks=4, predivide=True, e5m2_allgather=False,
do_not_flatten_model=False):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._amp_scale_adjustment = amp_scale_adjustment
if use_mt:
raise RuntimeError('DistributedFusedAdam does not support use_mt.')
if amsgrad:
raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
betas=betas, eps=eps, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(DistributedFusedAdamV3, self).__init__(params, defaults)
self.eps_mode = 0 if eps_inside_sqrt else 1
self._overflow_buf = torch.cuda.IntTensor([0])
assert (len(self.param_groups) == 1), "More than one parameter group is not supported."
# Way to revert a step
# 3 -> undo kernel + double buffer (debug, print norm of difference)
# 2 -> double buffer fp32 parameters
# 1 -> undo kernel
self._revert_method = revert_method
if self._revert_method > 1:
print("revert_method -> double buffer fp32 parameters, will consume more memory")
self._last_step = False
self._overlap_reductions = overlap_reductions
self._global_scale = None
self._num_blocks = dwu_num_blocks
self._predivide = predivide
self._e5m2_allgather = e5m2_allgather
self._do_not_flatten_model = do_not_flatten_model
self._full_pipeline = full_pipeline
self._L2_grad_norm = None
self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
self._world_size = torch.distributed.get_world_size()
self._num_groups = self._world_size // self._group_size
self._rank_in_group = torch.distributed.get_rank() % self._group_size
p_offset = 0
p_i = 0
self._param_state = None
self._model_params = []
self._grads_info = []
self._grad_accs = []
for group in self.param_groups:
self._param_group = group
prev = None
for p in group['params']:
torch.distributed.broadcast(p,0)
if not p.requires_grad:
continue
self._model_params.append(p)
state = self.state[p]
if len(state) == 0:
state['step'] = 0
if self._param_state is None:
self._param_state = state
p_grads_size = p.numel()
def wrapper(param, param_i, param_grads_size, param_offset):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
grad_acc.register_hook(allreduce_hook)
self._grad_accs.append(grad_acc)
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
wrapper(p, p_i, p_grads_size, p_offset)
p_offset += p_grads_size
# Only enforce 128-byte alignment (64 fp16 elements) for non-consecutive parameters
# RNN is one example of consecutive parameters:
# (weight_ih, weight_hh, bias_ih, bias_hh)
if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
p_offset = ((p_offset + 63) // 64) * 64
prev = p
p_i += 1
self._grads_generated = [False]*len(self._grads_info)
self._flat_mt = flat_mt
self._grads = []
self._current_block = self._num_blocks
self._net_total_param_size = p_offset
self._total_param_size = p_offset
dwu_min_page_size = 256 * self._num_blocks * self._group_size
self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
self._block_size = self._total_param_size // self._num_blocks
self._shard_size = self._total_param_size // self._group_size
print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._shard_size))
self._low_param_i = [0]*self._num_blocks
for block_id in range(self._num_blocks-1,-1,-1):
p_i = len(self._grads_info)-1
while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
p_i -= 1
self._low_param_i[block_id] = p_i
print(self._low_param_i)
self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
self._flat_params = torch.zeros_like(self._flat_grads)
def _flat_split(flat):
def __flat_blockify(flat):
return [flat[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
def __flat_shardify(flat):
return [flat[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
return __flat_blockify(flat), __flat_shardify(flat)
self._flat_grads_blocks, self._flat_grads_shards = _flat_split(self._flat_grads)
self._flat_params_blocks, self._flat_params_shards = _flat_split(self._flat_params)
# master params
self._fp32_p = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
self._fp32_m = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
self._fp32_v = torch.zeros([self._shard_size], dtype=torch.float32, device='cuda')
# copy model params to flat_params and set_ model params to flat_params.
self._individual_flat_grads = []
with torch.no_grad():
for p, grads_info in zip(self._model_params, self._grads_info):
start = grads_info["param_offset"]
end = start + grads_info["param_grads_size"]
flat_p = self._flat_params[start:end].view_as(p)
flat_p.copy_(p)
p.set_(flat_p)
flat_grad = self._flat_grads[start:end]
self._individual_flat_grads.append(flat_grad)
self._fp32_p.copy_(self._flat_params_shards[self._rank_in_group].float())
self._dwu_st = torch.cuda.Stream()
self._l2_grad_norm_st = torch.cuda.Stream()
for group_i in range(self._num_groups):
ranks = [group_i*self._group_size+local_rank for local_rank in range(self._group_size)]
pg = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._ag_pg = pg
torch.distributed.all_reduce(self._overflow_buf, group=self._ag_pg)
import inspect
assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
@property
def has_overflow(self):
return self.L2_grad_norm is not None and not math.isfinite(self.L2_grad_norm)
def set_last_step(self, last_step):
self._last_step = last_step
def _get_flush_block(self):
flush_block = []
if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
num_grads = len(self._grads_generated)
contiguous_idx = num_grads
while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
contiguous_idx -= 1
if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
self._current_block -= 1
start = self._current_block * self._block_size
end = (self._current_block+1) * self._block_size
flush_block = [start, end]
return flush_block
def __launch_step_kernel(self, p, p_copy, m, v, g):
combined_scale = self._global_scale
if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
combined_scale = self._global_scale / min(1, combined_scale)
bias_correction = 1 if self._param_group['bias_correction'] else 0
beta1, beta2 = self._param_group['betas']
fused_adam_cuda.reversible_adam(
p, p_copy, m, v, g,
self._param_group['lr'],
beta1,
beta2,
self._param_group['eps'],
combined_scale,
self._param_state['step']+1,
self.eps_mode,
bias_correction,
self._param_group['weight_decay'])
def _flatten_grad_mt(self, scale):
if self._flat_mt and len(self._grads) > 0:
self._overflow_buf.zero_()
multi_tensor_applier(
amp_C.multi_tensor_scale,
self._overflow_buf,
list(zip(*self._grads)),
scale)
self._grads = []
def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
# handle overlapped reductions
if self._flat_mt:
self._grads.append( (param.grad, self._individual_flat_grads[param_i]) )
else:
torch.div(param.grad, self._world_size if self._predivide else 1.0, out=self._individual_flat_grads[param_i])
self._grads_generated[param_i]=True
if not self._last_step and self._overlap_reductions:
flush_block = self._get_flush_block()
while flush_block:
block_id = flush_block[0] // self._block_size
self._dwu_st.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._dwu_st):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
torch.distributed.all_reduce(self._flat_grads_blocks[block_id])
if block_id == 0:
self._l2_grad_norm_st.wait_stream(self._dwu_st)
with torch.cuda.stream(self._l2_grad_norm_st):
self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
flush_block = self._get_flush_block()
def set_global_scale(self, global_scale):
"""Set global scale.
"""
self._global_scale = global_scale
@property
def global_scale(self):
return self._global_scale
@property
def L2_grad_norm(self):
torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
return self._L2_grad_norm
def complete_reductions(self):
"""Complete reductions if full pipeline is not selected or overlap is not allowed.
"""
if self._last_step:
# zero out gradients that have not been completed yet
for param_i, flat_grad in enumerate(self._individual_flat_grads):
if not self._grads_generated[param_i]:
flat_grad.zero_()
self._grads_generated[param_i] = True
if self._last_step or not self._overlap_reductions:
# nothing done so far, run full pipeline after reductions
self._dwu_st.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._dwu_st):
self._flatten_grad_mt(1.0/self._world_size if self._predivide else 1.0)
torch.distributed.all_reduce(self._flat_grads)
self._l2_grad_norm_st.wait_stream(self._dwu_st)
with torch.cuda.stream(self._l2_grad_norm_st):
self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float32, p=2).item()
self._current_block = self._num_blocks
self._grads_generated = [False]*len(self._grads_info)
def step(self, closure=None, skip_overflow_check=False):
loss = None
if closure is not None:
loss = closure()
with torch.cuda.stream(self._dwu_st):
self.__launch_step_kernel(
self._fp32_p,
self._flat_params_shards[self._rank_in_group],
self._fp32_m,
self._fp32_v,
self._flat_grads_shards[self._rank_in_group])
torch.distributed.all_gather(self._flat_params_shards, self._flat_params_shards[self._rank_in_group], group=self._ag_pg, no_copy=True)
for p in self._model_params: self.state[p]['step'] += 1
torch.cuda.current_stream().wait_stream(self._dwu_st)
return loss
from .peer_memory import PeerMemoryPool
from .peer_halo_exchanger_1d import PeerHaloExchanger1d
import torch
from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d
import peer_memory_cuda as pm
# How to run:
# torchrun --nproc_per_node <num-GPU> <this-python-prog>
# <num-GPU> must be a power of 2 greater than 1.
# Output of this function is used as ground truth in module tests.
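# Reference (ground-truth) halo exchange: every rank all_gathers its top and
# bottom output halos, then copies the bottom halo of the rank above and the top
# halo of the rank below into its own input-halo slices, with circular
# wrap-around at the ends. It is intentionally simple and exists only to
# validate the peer-memory implementation.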
def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_split):
if explicit_nhwc:
if H_split:
_, Hp, _, _ = list(y.shape)
H = Hp - 2*half_halo
top_out_halo = y[:,half_halo:2*half_halo,:,:]
top_inp_halo = y[:,:half_halo,:,:]
btm_out_halo = y[:,H:H+half_halo,:,:]
btm_inp_halo = y[:,H+half_halo:H+2*half_halo,:,:]
else:
_, _, Wp, _ = list(y.shape)
W = Wp - 2*half_halo
top_out_halo = y[:,:,half_halo:2*half_halo,:]
top_inp_halo = y[:,:,:half_halo,:]
btm_out_halo = y[:,:,W:W+half_halo,:]
btm_inp_halo = y[:,:,W+half_halo:W+2*half_halo,:]
else:
if H_split:
_, _, Hp, _ = list(y.shape)
H = Hp - 2*half_halo
top_out_halo = y[:,:,half_halo:2*half_halo,:]
top_inp_halo = y[:,:,:half_halo,:]
btm_out_halo = y[:,:,H:H+half_halo,:]
btm_inp_halo = y[:,:,H+half_halo:H+2*half_halo,:]
else:
_, _, _, Wp = list(y.shape)
W = Wp - 2*half_halo
top_out_halo = y[:,:,:,half_halo:2*half_halo]
top_inp_halo = y[:,:,:,:half_halo]
btm_out_halo = y[:,:,:,W:W+half_halo]
btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo]
top_out_halo = top_out_halo.clone(memory_format=torch.preserve_format)
btm_out_halo = btm_out_halo.clone(memory_format=torch.preserve_format)
top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)]
torch.distributed.all_gather(top_inp_halos, top_out_halo)
btm_inp_halos = [torch.empty_like(btm_out_halo) for _ in range(peer_group_size)]
torch.distributed.all_gather(btm_inp_halos, btm_out_halo)
top_rank = (peer_rank + peer_group_size - 1) % peer_group_size
btm_rank = (peer_rank + 1) % peer_group_size
top_inp_halo.copy_(btm_inp_halos[top_rank])
btm_inp_halo.copy_(top_inp_halos[btm_rank])
def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1):
if memory_format == 1:
# 1 -> explicit nhwc
explicit_nhwc = True
if H_split:
y = torch.randn([1,H+2*half_halo,W,C], dtype=dtype, device='cuda')
ym = y[:,half_halo:H+half_halo,:,:]
else:
y = torch.randn([1,H,W+2*half_halo,C], dtype=dtype, device='cuda')
ym = y[:,:,half_halo:W+half_halo,:]
else:
# 2 -> native nhwc
# 3 -> nchw
explicit_nhwc = False
if H_split:
y = torch.randn([1,C,H+2*half_halo,W], dtype=dtype, device='cuda')
if memory_format == 2:
y = y.to(memory_format=torch.channels_last)
ym = y[:,:,half_halo:H+half_halo,:]
else:
y = torch.randn([1,C,H,W+2*half_halo], dtype=dtype, device='cuda')
if memory_format == 2:
y = y.to(memory_format=torch.channels_last)
ym = y[:,:,:,half_halo:W+half_halo]
y3 = y.clone()
list_y = []
for step in range(num_steps):
halo_ex(y, H_split, explicit_nhwc, numSM)
list_y.append(y.clone())
y.copy_(y3)
halo_ex.peer_pool.reset()
torch.distributed.barrier()
y2 = y3.clone()
list_y2 = []
for step in range(num_steps):
nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split)
list_y2.append(y2.clone())
y2.copy_(y3)
is_equal = [torch.all(torch.eq(yy,yy2)) for yy,yy2 in zip(list_y,list_y2)]
is_equal = torch.tensor(is_equal, dtype=torch.bool)
is_equal = torch.all(is_equal)
if peer_rank == 0:
if memory_format == 1:
memory_format_str = "explicit_nhwc"
elif memory_format == 2:
memory_format_str = "native nhwc"
elif memory_format == 3:
memory_format_str = "nchw"
else:
memory_format_str = "???"
if is_equal:
print("SUCCESS : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s" % (C,H,W,half_halo,str(dtype),memory_format_str,"H-split" if H_split else "W-split"))
else:
print("FAILURE : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s" % (C,H,W,half_halo,str(dtype),memory_format_str,"H-split" if H_split else "W-split"))
# peer memory flag sync relies on there being at least one barrier per step
torch.distributed.barrier()
def H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
Hr = 8*world_size
Hp = ((H + Hr - 1) // Hr) * 8
for i in range(4):
div = int(pow(2,i))
single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 1, True, num_steps)
single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 2, True, num_steps)
single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 3, True, num_steps)
def W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps):
Wr = 8*world_size
Wp = ((W + Wr - 1) // Wr) * 8
for i in range(4):
div = int(pow(2,i))
single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 1, False, num_steps)
single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 2, False, num_steps)
single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 3, False, num_steps)
def main():
# for this trivial example peer_rank == rank and peer_group_size == world_size
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank)
pool = PeerMemoryPool(64*1024, 2*1024*1024)  # static_size, dynamic_size (matches PeerMemoryPool.__init__ below)
num_steps = 100
half_halo = 1
halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo)
H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex,num_steps)
W_split_tests(1,64,200,336, half_halo,rank,world_size,halo_ex,num_steps)
if __name__ == "__main__":
main()
import torch
from apex.contrib.peer_memory import PeerMemoryPool
import peer_memory_cuda as pm
class PeerHaloExchanger1d:
def __init__(self, rank, peer_group_size, peer_pool, half_halo):
self.peer_group_size = peer_group_size
self.peer_rank = rank % peer_group_size
self.peer_pool = peer_pool
self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False)
self.signals[self.peer_rank].zero_()
self.half_halo = half_halo
def __call__(self, y, H_split=True, explicit_nhwc=False, numSM=1, diagnostics=False):
channels_last = y.is_contiguous(memory_format=torch.channels_last) and not explicit_nhwc
if H_split:
if explicit_nhwc:
_, Hs, _, _ = list(y.shape)
H = Hs - 2*self.half_halo
top_out_halo = y[:,self.half_halo:2*self.half_halo,:,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:self.half_halo,:,:]
btm_out_halo = y[:,H:H+self.half_halo,:,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:]
else:
_, _, Hs, _ = list(y.shape)
H = Hs - 2*self.half_halo
top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:self.half_halo,:]
btm_out_halo = y[:,:,H:H+self.half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:]
else:
if explicit_nhwc:
_, _, Ws, _ = list(y.shape)
W = Ws - 2*self.half_halo
top_out_halo = y[:,:,self.half_halo:2*self.half_halo,:]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, False, True)
top_inp_halo = y[:,:,:self.half_halo,:]
btm_out_halo = y[:,:,W:W+self.half_halo,:]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, False, True)
btm_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:]
else:
_, _, _, Ws = list(y.shape)
W = Ws - 2*self.half_halo
top_out_halo = y[:,:,:,self.half_halo:2*self.half_halo]
top_tx = self.peer_pool.allocate_peer_tensors(list(top_out_halo.shape), top_out_halo.dtype, channels_last, True)
top_inp_halo = y[:,:,:,:self.half_halo]
btm_out_halo = y[:,:,:,W:W+self.half_halo]
btm_tx = self.peer_pool.allocate_peer_tensors(list(btm_out_halo.shape), btm_out_halo.dtype, channels_last, True)
btm_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo]
top_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size
btm_neighbor = (self.peer_rank + 1) % self.peer_group_size
pm.push_pull_halos_1d(
diagnostics, explicit_nhwc, numSM,
top_out_halo, top_tx[self.peer_rank], btm_tx[top_neighbor], top_inp_halo,
btm_out_halo, btm_tx[self.peer_rank], top_tx[btm_neighbor], btm_inp_halo,
self.signals[top_neighbor], self.signals[btm_neighbor], self.signals[self.peer_rank]
)
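# Usage sketch (added comment, not part of the original file; pool sizes, tensor shape and dtype
# are assumptions): each rank pads its split dimension by 2*half_halo and lets the exchanger fill
# the halo slices from its neighbours' transmit buffers, e.g.
#   pool = PeerMemoryPool(64*1024, 2*1024*1024)
#   halo_ex = PeerHaloExchanger1d(rank, world_size, pool, half_halo=1)
#   y = torch.randn(1, 64, H_local + 2, W, dtype=torch.float16, device="cuda")  # NCHW, H padded by 2*half_halo
#   halo_ex(y, H_split=True)  # top/bottom halo rows of y now hold the neighbours' boundary rows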
import torch
import numpy as np
import peer_memory_cuda as pm
class PeerMemoryPool(object):
def __init__(self, static_size, dynamic_size, peer_ranks=None):
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
ngpus = min(torch.cuda.device_count(), world_size)
peer_group_size = ngpus
peer_group = rank // ngpus
peer_rank_base = peer_group * ngpus
peer_rank = rank - peer_rank_base
if peer_ranks is None:
peer_ranks = [i+peer_rank_base for i in range(peer_group_size)]
peer_rank_start = peer_rank_base
peer_rank_end = peer_rank_start + peer_group_size - 1
for pr in peer_ranks:
assert(pr >= peer_rank_start and pr <= peer_rank_end), "%d :: peer_rank %d not on same node (ranks=[%d,%d])" % (rank, pr, peer_rank_start, peer_rank_end)
self.alignment = 256
self.static_size = ((static_size + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_size = ((dynamic_size + self.alignment - 1) // self.alignment) * self.alignment
# allocate giant pool of device memory
self.raw = pm.allocate_raw(self.static_size+self.dynamic_size)
# exchange peer pointers with nccl
raw_ipc = pm.get_raw_ipc_address(self.raw).cuda()
peer_raw_ipcs = [torch.empty_like(raw_ipc) for _ in range(world_size)]
torch.distributed.all_gather(peer_raw_ipcs, raw_ipc)
peer_raw_ipcs = torch.stack(peer_raw_ipcs).cpu()
# extract IPC pointers for ranks on same node
peer_raw = pm.get_raw_peers(peer_raw_ipcs[peer_rank_base:peer_rank_base+ngpus], peer_rank, self.raw)
self.peer_raw = [peer_raw[peer_rank-peer_rank_base] for peer_rank in peer_ranks]
self.static_offset = 0
self.dynamic_offset = 0
self.peer_ranks = peer_ranks
def __del__(self):
pm.free_raw(self.raw)
def reset(self):
self.dynamic_offset = 0
def allocate_peer_tensors(self, shape, dtype, channels_last, dynamic):
nels = np.prod(shape)
if dtype == torch.float16:
elem_size = 2
if dynamic:
start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_offset = start + nels * elem_size
assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
return [pm.blob_view_half(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
else:
start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
self.static_offset = start + nels * elem_size
assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
return [pm.blob_view_half(pr + start, shape, channels_last) for pr in self.peer_raw]
if dtype == torch.float32:
elem_size = 4
if dynamic:
start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_offset = start + nels * elem_size
assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
return [pm.blob_view_float(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
else:
start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
self.static_offset = start + nels * elem_size
assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
return [pm.blob_view_float(pr + start, shape, channels_last) for pr in self.peer_raw]
if dtype == torch.int32:
elem_size = 4
if dynamic:
start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
self.dynamic_offset = start + nels * elem_size
assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
return [pm.blob_view_int(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
else:
start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
self.static_offset = start + nels * elem_size
assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
return [pm.blob_view_int(pr + start, shape, channels_last) for pr in self.peer_raw]
else:
assert(False), "dtype %s not supported" % (str(dtype))
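# Illustrative note (added comment, not part of the original file; sizes are assumptions):
# allocate_peer_tensors returns one tensor view per entry in peer_ranks, each aliasing the same
# offset inside the corresponding peer's memory pool, e.g.
#   pool = PeerMemoryPool(64*1024, 2*1024*1024)
#   flags = pool.allocate_peer_tensors([2, 4], torch.int32, False, False)  # static allocation
#   flags[local_peer_index].zero_()  # by convention each rank writes only to its own view (cf. PeerHaloExchanger1d.signals)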
...@@ -3,6 +3,7 @@
This serves as a quick-start for ASP (Automatic SParsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python.
## Importing ASP
```
from apex.contrib.sparsity import ASP
```
...@@ -10,11 +11,13 @@ from apex.contrib.sparsity import ASP
## Initializing ASP
Apart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/inference:
```
ASP.prune_trained_model(model, optimizer)
```
In the context of a typical PyTorch training loop, it might look like this:
```
ASP.prune_trained_model(model, optimizer)
...@@ -27,6 +30,7 @@ for epoch in range(epochs):
torch.save(...)
```
The `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step.
## Generate a Sparse Network
...@@ -42,7 +46,6 @@ The following approach serves as a guiding example on how to generate a pruned m
In code, below is a sketch on how to use ASP for this approach (steps 1 and 2 above).
```
model = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint)
criterion = ... # compare ground truth with model prediction; use the same criterion as used to generate the dense trained model
optimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model
...@@ -72,7 +75,60 @@ ASP.compute_sparse_masks()
A more thorough example can be found in `./test/toy_problem.py`.
## Advanced Usage: Channel Permutation
We introduce channel permutations as an advanced method to maximize the accuracy of structured sparse networks. By permuting weight matrices along their channel dimension and adjusting the surrounding layers appropriately, we demonstrate accuracy recovery for even small, parameter-efficient networks, without affecting inference run-time.
The final accuracy depends strongly on the quality of the permutations. We provide default algorithms to search for high-quality permutations, and the search process can be accelerated by the Apex CUDA extension `apex.contrib.sparsity.permutation_search_kernels`.
If you want to use the GPU to accelerate the permutation search process, we recommend installing Apex with the permutation search CUDA extension via
```
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--permutation_search" ./
```
If you want to disable the permutation search process, please pass `allow_permutation=False` to the `init_model_for_pruning` function. For example:
```
ASP.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False, allow_permutation=False)
```
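With permutation search left enabled (the default), a minimal sketch of pruning a trained model, assuming `model` and `optimizer` are already constructed, looks like this:
```
ASP.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False, allow_permutation=True)
ASP.init_optimizer_for_pruning(optimizer)
ASP.compute_sparse_masks()  # permutation search and offline permutation run inside this call
```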
Please note that when using multiple GPUs, the identical random seed should be set on all GPUs to make sure the permutation search generates the same results. The library implements a `set_identical_seed` function in `permutation_lib.py`, which is called from the ASP library. We still suggest that users set the identical random seed in their own multi-GPU code, for example:
```
import torch
import numpy
import random
torch.manual_seed(identical_seed)
torch.cuda.manual_seed_all(identical_seed)
numpy.random.seed(identical_seed)
random.seed(identical_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```
## Reference Papers
More details about sparsity support on NVIDIA Ampere GPUs with Sparse Tensor Cores can be found in our [white paper](https://arxiv.org/abs/2104.08378).
```
@article{mishra2021accelerating,
title={Accelerating sparse deep neural networks},
author={Mishra, Asit and Latorre, Jorge Albericio and Pool, Jeff and Stosic, Darko and Stosic, Dusan and Venkatesh, Ganesh and Yu, Chong and Micikevicius, Paulius},
journal={arXiv preprint arXiv:2104.08378},
year={2021}
}
```
Details about sparsity with permutation can be found in our [paper](https://proceedings.neurips.cc/paper/2021/hash/6e8404c3b93a9527c8db241a1846599a-Abstract.html) published in *Thirty-fifth Conference on Neural Information Processing Systems* (**NeurIPS 2021**):
```
@article{pool2021channel,
title={Channel Permutations for N:M Sparsity},
author={Pool, Jeff and Yu, Chong},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}
```
import types
import torch
from .sparse_masklib import create_mask
from .permutation_lib import Permutation
torchvision_imported=True
try:
...@@ -9,6 +10,11 @@ except ImportError:
print("[ASP][Warning] torchvision cannot be imported.")
torchvision_imported=False
import json
import os
import string
import time
def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names):
eligible_modules_list = []
for name, mod in model.named_modules():
...@@ -18,19 +24,25 @@ def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallow
eligible_modules_list.append((name, mod))
return eligible_modules_list
class ASP:
__model = None
__verbosity = 0
__optimizer = None
__sparse_parameters = []
__calculate_mask = None
__allow_permutation = True
__all_parameters = []
__save_permutation_graph = False
__permutation_output_dir = ''
@classmethod
def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
verbosity=3,
whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d],
allowed_layer_names=None, disallowed_layer_names=[],
allow_recompute_mask=False, custom_layer_dict={}):
allow_recompute_mask=False, custom_layer_dict={},
allow_permutation=True):
"""Call this method to modify your model to take advantage of sparse matrix multiplication.
Note that this call alone only augments the model with additional buffers needed for sparse MMA,
it does not enable use of sparse MMA.
...@@ -63,12 +75,14 @@ class ASP:
allow_recompute_mask If True, stores pruned values so that dense weights can be restored.
Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
custom_layer_dict Dictionary of additional layer parameters to sparsify. e.g. {CustomLinear: ['weight']}
allow_permutation If True, allow the input channel permutation to ease the influence of weight pruning.
[Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe -- AKM.
[Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe.
"""
assert (cls.__model is None), "ASP has been initialized already."
cls.__model = model
cls.__verbosity = verbosity
cls.__allow_permutation = allow_permutation
if isinstance(mask_calculator, str):
def create_mask_from_pattern(param):
...@@ -91,6 +105,28 @@ class ASP:
for module_type in whitelist:
assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module_type
if allow_permutation: # find all named modules, extract parameters and decorate, used for offline permutation in K dim
for module_name, module in model.named_modules():
module_type_str = str(type(module)).split("\'")[1]
if module_type_str == 'torch.nn.modules.container.Sequential' or module_type_str.startswith('torchvision.models'):
# filter out the 'torch.nn.modules.container.Sequential' type and the whole model, like 'torchvision.models.vgg.VGG'
continue
for p_name, p in module.named_parameters():
cls.__all_parameters.append((module_name, module, p_name, p))
if module_type_str == 'torch.nn.modules.batchnorm.BatchNorm2d':
# need to get the running_mean and running_var from model.state_dict(), as they are not the learnable parameters
module_mean_name = module_name + '.running_mean'
module_var_name = module_name + '.running_var'
for param_key in model.state_dict():
if module_mean_name == param_key or module_var_name == param_key:
cls.__all_parameters.append((module_name, module, param_key.split(".")[-1], model.state_dict()[param_key]))
# add the __permutation_output_dir field to save the intermediate results for permutation
cls.__permutation_output_dir = '.'
# Set the corresponding params from ASP class to the Permutation class
Permutation.set_permutation_params_from_asp(cls.__model, cls.__sparse_parameters, cls.__all_parameters)
# Set the identical random seed for all GPUs to make sure the same results generated in permutation search
Permutation.set_identical_seed()
# find all sparse modules, extract sparse parameters and decorate
def add_sparse_attributes(module_name, module):
sparse_parameters = sparse_parameter_list[type(module)]
...@@ -123,6 +159,19 @@ class ASP:
for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names):
add_sparse_attributes(name, sparse_module)
@classmethod
def already_init_asp_model(cls):
"""Call this method to check whether ASP has been initialized already.
"""
if cls.__model is None:
if cls.__verbosity >= 3:
print("[ASP] ASP has not been initialized.")
return False
else:
if cls.__verbosity >= 3:
print("[ASP] ASP has been initialized already.")
return True
@classmethod
def init_optimizer_for_pruning(cls, optimizer):
"""Call this method to monkey patch optimizer step function so that masks can be applied to
...@@ -157,6 +206,38 @@ class ASP:
If init(...) was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None.
"""
with torch.no_grad():
if cls.__allow_permutation:
# Step 1: use the Torch.FX library to build the graph
# Step 2: permutation search with the customized kernel
# Notice: need to use the single GPU to build the Torch.FX graph
# The simplest approach without user intervention:
# A. try to build the graph from the distributed (wrapped) model
# B. if that raises an AttributeError, build it from the non-distributed model
start_time_build_offline_permutation_graph = time.perf_counter()
try:
offline_permutation_fx_graph, success_in_build_offline_permutation_graph = Permutation.build_offline_permutation_graph(cls.__model.module, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json'))
print("\n[compute_sparse_masks] build offline permutation graph on distributed model.")
except AttributeError:
offline_permutation_fx_graph, success_in_build_offline_permutation_graph = Permutation.build_offline_permutation_graph(cls.__model, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json'))
print("\n[compute_sparse_masks] build offline permutation graph on non-distributed model.")
duration_build_offline_permutation_graph = time.perf_counter() - start_time_build_offline_permutation_graph
print("[compute_sparse_masks] Take {:.4f} seconds to finish build_offline_permutation_graph function.".format(duration_build_offline_permutation_graph))
# Step 3: off-line permutation to avoid the runtime overhead in deployment
if success_in_build_offline_permutation_graph:
start_time_apply_offline_permutation = time.perf_counter()
try:
Permutation.apply_offline_permutation(cls.__model.module, fx_graph=offline_permutation_fx_graph)
print("\n[compute_sparse_masks] apply offline permutation on distributed model.")
except AttributeError:
Permutation.apply_offline_permutation(cls.__model, fx_graph=offline_permutation_fx_graph)
print("\n[compute_sparse_masks] apply offline permutation on non-distributed model.")
duration_apply_offline_permutation = time.perf_counter() - start_time_apply_offline_permutation
print("[compute_sparse_masks] Take {:.4f} seconds to finish apply_offline_permutation function.\n".format(duration_apply_offline_permutation))
else:
print("[compute_sparse_masks] skip applying offline permutation because there is no valid offline_permutation_fx_graph.")
# Finally, permutation search and off-line permutation are done; give the model back to ASP to generate the normal structured sparse mask
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
if mask.sum() < mask.numel(): # when recalculating masks
# restore dense parameter if allow_recompute_mask is enabled
...@@ -170,7 +251,7 @@ class ASP:
p.mul_(mask) # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights
if cls.__verbosity >= 2:
print("[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s" % (100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype)))
print("[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s" % (100.0-100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype)))
@classmethod
def restore_pruned_weights(cls):
...@@ -215,3 +296,17 @@ class ASP:
cls.init_optimizer_for_pruning(optimizer)
cls.compute_sparse_masks()
@classmethod
def set_permutation_saving_params(cls, allow_permutation=True, save_permutation_graph=False, permutation_output_dir='.'):
"""This function is used to set the permutation saving related parameters in ASP class and inside of the Permutation class."""
print("\n[ASP][set_permutation_saving_param] Set permutation saving related parameters")
print("\n[set_permutation_saving_param] Set permutation saving related parameters")
cls.__allow_permutation = allow_permutation
print("[set_permutation_saving_param]\t Allow permutation: {}".format(cls.__allow_permutation))
cls.__save_permutation_graph = save_permutation_graph
print("[set_permutation_saving_param]\t Save permutation graphs: {}".format(cls.__save_permutation_graph))
cls.__permutation_output_dir = permutation_output_dir
print("[set_permutation_saving_param]\t Permutation graphs saving dir: {}".format(cls.__permutation_output_dir))
Permutation.set_permutation_saving_params(allow_permutation, save_permutation_graph, permutation_output_dir)
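# Illustrative call (added comment, not part of the original file; the output directory name is an
# assumption): to keep the intermediate permutation graphs for debugging, one could run
#   ASP.set_permutation_saving_params(allow_permutation=True, save_permutation_graph=True,
#                                     permutation_output_dir='./permutation_graphs')
# before calling ASP.prune_trained_model(model, optimizer).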
import os
import torch
import json
import string
import time
try:
from .permutation_search_kernels import accelerated_search_for_good_permutation, sum_after_2_to_4
print("[ASP][Info] permutation_search_kernels can be imported.")
except ImportError:
print("[ASP][Warning] permutation_search_kernels cannot be imported.")
print("[ASP][Warning] If you want to accelerate the permutation search process by GPU, please build APEX by following the instructions at https://github.com/NVIDIA/apex/blob/master/apex/contrib/sparsity/README.md")
def convert_fx_node_name(fx_node_name):
converted_fx_node_name = fx_node_name
converted_fx_node_name = converted_fx_node_name.replace('_', '.')
return converted_fx_node_name
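# Example (added comment, hypothetical node name): an FX node named 'layer1_0_conv1' is converted
# to 'layer1.0.conv1', so it can be matched against module names from model.named_modules().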
def get_node_parent_children(fx_node):
# get node parent list, and convert node name to module name
node_parent_name_converted = []
if len(fx_node.all_input_nodes) > 0:
node_parent = fx_node.all_input_nodes
for item in node_parent:
converted_item = convert_fx_node_name(item.name)
node_parent_name_converted.append(converted_item)
else:
node_parent = list('None')
node_parent_name_converted.append('None')
# get node children list, and convert node name to module name
node_children_name_converted = []
if len(list(fx_node.users.keys())) > 0:
node_children = list(fx_node.users.keys())
for item in node_children:
converted_item = convert_fx_node_name(item.name)
node_children_name_converted.append(converted_item)
else:
node_children = list('None')
node_children_name_converted.append('None')
return node_parent_name_converted, node_children_name_converted
class Permutation:
__model = None
__sparse_parameters = []
__allow_permutation = False
__all_parameters = []
__save_permutation_graph = False
__permutation_output_dir = ''
@classmethod
def set_permutation_params_from_asp(cls, model, sparse_parameters, all_parameters):
"""This function is used to set the permutation needed parameters from ASP class."""
print("\n[set_permutation_params_from_asp] Set permutation needed parameters")
cls.__model = model
cls.__sparse_parameters = sparse_parameters
cls.__all_parameters = all_parameters
@classmethod
def set_identical_seed(cls, identical_seed=1):
print("\n[set_identical_seed] Set the identical seed: {:} for all GPUs to make sure the same results generated in permutation search".format(identical_seed))
torch.manual_seed(identical_seed)
torch.cuda.manual_seed_all(identical_seed)
import numpy as np
import random
np.random.seed(identical_seed)
random.seed(identical_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
@classmethod
def set_permutation_saving_params(cls, allow_permutation=False, save_permutation_graph=False, permutation_output_dir='.'):
"""This function is used to set the permutation saving related parameters."""
print("\n[permutation_lib][set_permutation_saving_param] Set permutation saving related parameters")
cls.__allow_permutation = allow_permutation
print("[set_permutation_saving_param]\t Allow permutation: {}".format(cls.__allow_permutation))
cls.__save_permutation_graph = save_permutation_graph
print("[set_permutation_saving_param]\t Save permutation graphs: {}".format(cls.__save_permutation_graph))
cls.__permutation_output_dir = permutation_output_dir
print("[set_permutation_saving_param]\t Permutation graphs saving dir: {}".format(cls.__permutation_output_dir))
@classmethod
def apply_offline_permutation(cls, model, fx_graph):
"""This function is used to apply offline permutation to each node according to the whole network graph built with Torch.FX."""
print("\n[apply_offline_permutation] Apply offline permutation to each node according to the whole network graph built with Torch.FX")
# Firstly, we should transfer the sparse mask to all-one dense mask
cls.transfer_to_dense_mask()
for node_name in fx_graph.keys():
node_module_type = fx_graph.get(node_name).get('module_type')
# check whether the current layer can be permuted as planned, e.g., the flatten layer in VGG will change the shape and break the permutation chain
# only need to check 'is_node_real_parents_K_permuted', because 'is_node_real_parents_C_permuted' has no influence on the children
node_real_parents = fx_graph.get(node_name).get('real_parents')
is_node_real_parents_K_permuted = True
if node_real_parents is not None: # filter out the 'unique_siblings' item
for real_parent_item in node_real_parents:
if fx_graph.get(real_parent_item).get('permutation_type') in ['K', 'KC']:
if fx_graph.get(real_parent_item).get('k_permuted') == 'False':
is_node_real_parents_K_permuted = False
if fx_graph[node_name]['permutation_type'] == 'KC': # intermediate Conv, FC
C_permutation_sequence = cls.fetch_C_permutation_sequence_value(node_name, fx_graph)
K_permutation_sequence = cls.fetch_K_permutation_sequence_value(node_name, fx_graph)
print("\n[apply_offline_permutation] node_name: \'{:}\', node module type: \'{:}\', need to do offline permutation in K and C dims.".format(node_name, node_module_type))
if is_node_real_parents_K_permuted == True:
fx_graph[node_name]['c_permuted'] = str(cls.apply_permutation_in_C_dim(node_name, C_permutation_sequence))
fx_graph[node_name]['k_permuted'] = str(cls.apply_permutation_in_K_dim(node_name, K_permutation_sequence))
else:
print("[apply_offline_permutation][warning] node_name: \'{:}\', its real parents have trouble in permutation in K dim, so skip the offline permutation in C dim.".format(node_name, node_module_type))
fx_graph[node_name]['k_permuted'] = str(cls.apply_permutation_in_K_dim(node_name, K_permutation_sequence))
elif fx_graph[node_name]['permutation_type'] == 'K': # BN, first layer Conv/FC
K_permutation_sequence = cls.fetch_K_permutation_sequence_value(node_name, fx_graph)
print("\n[apply_offline_permutation] node_name: \'{:}\', node module type: \'{:}\', need to do offline permutation in K dim.".format(node_name, node_module_type))
if is_node_real_parents_K_permuted == True:
fx_graph[node_name]['k_permuted'] = str(cls.apply_permutation_in_K_dim(node_name, K_permutation_sequence))
else: # for BN, if the previous Conv cannot do permutation in K dim, then no need to do permutation in K dim for this BN
print("[apply_offline_permutation][warning] node_name: \'{:}\', its real parents have trouble in permutation in K dim, so skip the offline permutation in K dim.".format(node_name, node_module_type))
elif fx_graph[node_name]['permutation_type'] == 'C': # last layer FC/Conv
C_permutation_sequence = cls.fetch_C_permutation_sequence_value(node_name, fx_graph)
print("\n[apply_offline_permutation] node_name: \'{:}\', node module type: \'{:}\', need to do offline permutation in C dim.".format(node_name, node_module_type))
if is_node_real_parents_K_permuted == True:
fx_graph[node_name]['c_permuted'] = str(cls.apply_permutation_in_C_dim(node_name, C_permutation_sequence))
else:
print("[apply_offline_permutation][warning] node_name: \'{:}\', its real parents have trouble in permutation in K dim, so skip the offline permutation in C dim.".format(node_name, node_module_type))
if cls.__save_permutation_graph:
cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_apply_offline_permutation.json')) # save the intermediate graph as JSON file for debugging
return fx_graph
@classmethod
def transfer_to_dense_mask(cls):
"""Call this method to transfer the sparse mask to all-one dense mask."""
with torch.no_grad():
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
mask.fill_(1)
@classmethod
def fetch_C_permutation_sequence_value(cls, node_name, fx_graph):
"""This function is used to fetch the permutation sequence value in C dim from the unique_siblings record."""
# C_permutation_sequence is the corresponding 'permutation_sequence' value stored in the fx_graph.get('unique_siblings') item which contains node_name
unique_siblings_groups = fx_graph.get('unique_siblings').get('name')
unique_siblings_groups_permutation_sequence = fx_graph.get('unique_siblings').get('permutation_sequence')
item_index = 0
fetched_C_permutation_sequence = []
for item in unique_siblings_groups:
if node_name in item:
fetched_C_permutation_sequence = unique_siblings_groups_permutation_sequence[item_index]
item_index = item_index + 1
return fetched_C_permutation_sequence
@classmethod
def fetch_K_permutation_sequence_value(cls, node_name, fx_graph):
"""This function is used to fetch the permutation sequence value in K dim from the unique_siblings record."""
# K_permutation_sequence is its real_children's corresponding 'permutation_sequence' value stored in the fx_graph.get('unique_siblings') item which contains real_children name
# we have the assumption that all the real children are in one unique_sibling group, so should share the same permutation_sequence value
unique_siblings_groups = fx_graph.get('unique_siblings').get('name')
unique_siblings_groups_permutation_sequence = fx_graph.get('unique_siblings').get('permutation_sequence')
node_real_children = fx_graph.get(node_name).get('real_children')
fetched_K_permutation_sequence = []
if len(node_real_children) > 0:
node_representative_child = node_real_children[0]
fetched_K_permutation_sequence = cls.fetch_C_permutation_sequence_value(node_representative_child, fx_graph)
return fetched_K_permutation_sequence
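# Illustrative note (added comment): if conv1 feeds conv2 and conv2's input channels (C dim) are
# permuted with sequence S, then conv1 must reorder its output channels (K dim) with the same S;
# this is why a node's K permutation sequence is fetched from its real children's C sequence.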
@classmethod
def apply_permutation_in_C_dim(cls, node_name, permutation_sequence):
"""This function is used to apply the permutation to a node in C dim. (Only need to handle the weight of the node)"""
print("[apply_permutation_in_C_dim] Permutation for node: \'{:}\' in C dim".format(node_name))
if len(permutation_sequence) == 0:
print("[apply_permutation_in_C_dim] the permutation sequence is empty, fail to apply permutation in C dim.")
return False
is_node_in_sparse_parameters = False
success_permutation = False
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
processed_module_name = ''.join(c for c in module_name if c not in string.punctuation).lower()
processed_node_name = ''.join(c for c in node_name if c not in string.punctuation).lower()
distributed_node_name = 'module.' + node_name
processed_distributed_node_name = 'module.' + processed_node_name
if (module_name == node_name) or (module_name == distributed_node_name) or (processed_module_name == processed_node_name) or (processed_module_name == processed_distributed_node_name): # Inception-V3, module_name: Conv2d_2a_3x3.conv, node_name: conv2d.1a.3x3.conv
print("[apply_permutation_in_C_dim] find the node: \'{:}\' in cls.__sparse_parameters, succeed to apply permutation in C dim.".format(node_name))
is_node_in_sparse_parameters = True
temp_weight = torch.zeros_like(p)
temp_weight.copy_(p[:, permutation_sequence, ...])
p.data.copy_(temp_weight)
success_permutation = True
if is_node_in_sparse_parameters == False:
# A special case: if the node itself is not in sparse_module_names but one of its real_siblings is, the node will not take part in the permutation search, but it may still need to apply the offline permutation in C dim according to the permutation sequence searched for its real_siblings in sparse_module_names
try:
for module_name_from_all_parameters, module_from_all_parameters, p_name_from_all_parameters, p_from_all_parameters in cls.__all_parameters:
if ((node_name == module_name_from_all_parameters) or ('module.' + node_name == module_name_from_all_parameters)) and p_name_from_all_parameters == "weight":
print("[apply_permutation_in_C_dim] cannot find the node: \'{:}\' in cls.__sparse_parameters, but can find in cls.__all_parameters.".format(node_name))
temp_weight = torch.zeros_like(p_from_all_parameters)
temp_weight.copy_(p_from_all_parameters[:, permutation_sequence, ...])
p_from_all_parameters.data.copy_(temp_weight)
success_permutation = True
print("[apply_permutation_in_C_dim] cannot find the node: \'{:}\' in cls.__sparse_parameters, after trying with cls.__all_parameters, succeed to apply permutation in C dim.".format(node_name))
except:
success_permutation = False
print("[apply_permutation_in_C_dim] cannot find the node: \'{:}\' in cls.__sparse_parameters, after trying with cls.__all_parameters, still fail to apply permutation in C dim.".format(node_name))
return success_permutation
@classmethod
def apply_permutation_in_K_dim(cls, node_name, permutation_sequence):
"""This function is used to apply the permutation to a node in K dim. (Need to handle the weight/bias/running_mean/running_var of the node)"""
print("[apply_permutation_in_K_dim] Permutation for node: \'{:}\' in K dim".format(node_name))
if len(permutation_sequence) == 0:
print("[apply_permutation_in_K_dim] the permutation sequence is empty, fail to apply permutation in K dim.")
return False
is_node_in_all_parameters = False
success_permutation = False
for module_name, module, p_name, p in cls.__all_parameters:
processed_module_name = ''.join(c for c in module_name if c not in string.punctuation).lower()
processed_node_name = ''.join(c for c in node_name if c not in string.punctuation).lower()
distributed_node_name = 'module.' + node_name
processed_distributed_node_name = 'module.' + processed_node_name
if (module_name == node_name) or (module_name == distributed_node_name) or (processed_module_name == processed_node_name) or (processed_module_name == processed_distributed_node_name): # Inception-V3, module_name: Conv2d_2a_3x3.conv, node_name: conv2d.1a.3x3.conv
print("[apply_permutation_in_K_dim] find the node: \'{:}\' with \'{:}\' in cls.__all_parameters, may succeed to apply permutation in K dim.".format(node_name, p_name))
is_node_in_all_parameters = True
temp_weight = torch.zeros_like(p)
if p.shape[0] != len(permutation_sequence):
print("[apply_permutation_in_K_dim][warning] the node: \'{:}\' with shape: \'{:}\', cannot match the size of permutation sequence with len: \'{:}\', fail to apply permutation in K dim.".format(node_name, p.shape, len(permutation_sequence)))
success_permutation = False
else:
print("[apply_permutation_in_K_dim] the node: \'{:}\' with shape: \'{:}\', can match the size of permutation sequence with len: \'{:}\', succeed to apply permutation in K dim.".format(node_name, p.shape, len(permutation_sequence)))
temp_weight.copy_(p[permutation_sequence, ...])
p.data.copy_(temp_weight)
success_permutation = True
if is_node_in_all_parameters == False:
print("[apply_permutation_in_K_dim] cannot find the node: \'{:}\' in cls.__all_parameters, fail to apply permutation in K dim.".format(node_name))
success_permutation = False
return success_permutation
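# Illustrative example (added comment, hypothetical shapes): for a Conv2d weight of shape
# (K=64, C=128, 3, 3), apply_permutation_in_C_dim reorders dim 1 via p[:, permutation_sequence, ...]
# while apply_permutation_in_K_dim reorders dim 0 via p[permutation_sequence, ...]; for a following
# BatchNorm2d, the 1-D weight/bias/running_mean/running_var (length K) are reordered the same way.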
@classmethod
def build_offline_permutation_graph(cls, model, dump_fx_graph=False, save_dumped_fx_graph='./model_offline_permutation_graph.json'):
"""This function is used to refine the whole network graph built with Torch.FX with some extra information needed for offline permutation."""
print("\n[build_offline_permutation_graph] Further refine the model graph built by Torch.FX for offline permutation")
# extract the output_dir, so all the intermediate fx_graph can be saved under that path
extract_output_dir=os.path.split(save_dumped_fx_graph)[0]
cls.__permutation_output_dir = extract_output_dir
fx_graph, success_in_build_fx_graph = cls.build_fx_graph(model, dump_fx_graph=dump_fx_graph, save_dumped_fx_graph=save_dumped_fx_graph)
if success_in_build_fx_graph:
fx_graph_after_find_real_parents = cls.find_real_parents(fx_graph)
fx_graph_after_find_real_children = cls.find_real_children(fx_graph_after_find_real_parents)
fx_graph_after_find_real_siblings = cls.find_real_siblings(fx_graph_after_find_real_children)
fx_graph_after_extract_all_unique_siblings = cls.extract_all_unique_siblings(fx_graph_after_find_real_siblings)
fx_graph_after_init_permutation_flag = cls.init_permutation_flag(fx_graph_after_extract_all_unique_siblings)
start_time_search_for_good_permutation = time.perf_counter()
fx_graph_after_search_for_good_permutation = cls.search_for_good_permutation(fx_graph_after_init_permutation_flag)
duration_search_for_good_permutation = time.perf_counter() - start_time_search_for_good_permutation
print("\n[build_offline_permutation_graph] Take {:.4f} seconds to finish search_for_good_permutation function.".format(duration_search_for_good_permutation))
else:
fx_graph_after_search_for_good_permutation = {}
return fx_graph_after_search_for_good_permutation, success_in_build_fx_graph
# Please notice the apply_offline_permutation step cannot be folded into the above search_for_good_permutation step.
# This is because a real_parent node needs to be offline-permuted in the K direction according to the permutation sequence searched for its real_children.
# However, when search_for_good_permutation runs for a node, its real_children have not been handled by search_for_good_permutation yet.
if cls.__save_permutation_graph:
cls.save_graph_to_json(fx_graph_after_search_for_good_permutation, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_build_offline_permutation_graph.json')) # save the intermediate graph as JSON file for debugging
return fx_graph_after_search_for_good_permutation, success_in_build_fx_graph
@classmethod
def search_for_good_permutation(cls, fx_graph):
"""This function is used to:
1. search for the good permutation sequence for each node weight, or each siblings_group weights by calling the permutation search kernels as ASP extension.
2. add the searched permutation sequence for each node according to the whole network graph built with Torch.FX."""
print("\n[search_for_good_permutation] Search for the good permutation sequence for each node according to the whole network graph built with Torch.FX")
unique_siblings_groups = fx_graph.get('unique_siblings').get('name')
unique_siblings_groups_module_type = fx_graph.get('unique_siblings').get('module_type')
unique_siblings_groups_permutation_sequence = []
item_index = 0
for unique_siblings_group in unique_siblings_groups: # loop through all unique siblings groups that must share a permutation sequence
print("\n[search_for_good_permutation] this unique_siblings_group has {:} real siblings: \'{:}\', with module type: \'{:}\'.".format(len(unique_siblings_group), unique_siblings_group, unique_siblings_groups_module_type[item_index]))
item_index = item_index + 1
# concat the weight for layers in the same unique_siblings_group
matrix_group = None
for node_name in unique_siblings_group:
node_module_type = fx_graph.get(node_name).get('module_type')
print("[search_for_good_permutation] try to merge the weight for node: \'{:}\', with module type: \'{:}\'.".format(node_name, node_module_type))
is_node_in_sparse_parameters = False
node_weight = torch.zeros(0)
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
processed_module_name = ''.join(c for c in module_name if c not in string.punctuation).lower()
processed_node_name = ''.join(c for c in node_name if c not in string.punctuation).lower()
distributed_node_name = 'module.' + node_name
processed_distributed_node_name = 'module.' + processed_node_name
if (module_name == node_name) or (module_name == distributed_node_name) or (processed_module_name == processed_node_name) or (processed_module_name == processed_distributed_node_name): # Inception-V3, module_name: Conv2d_2a_3x3.conv, node_name: conv2d.1a.3x3.conv
module_type_from_sparse_parameters = str(type(module)) # e.g. <class 'torch.nn.modules.conv.Conv2d'>
module_type_from_sparse_parameters = module_type_from_sparse_parameters[8:-2]
print("[search_for_good_permutation] find the node: \'{:}\' in cls.__sparse_parameters, module type match: \'{:}\'.".format(node_name, node_module_type==module_type_from_sparse_parameters))
is_node_in_sparse_parameters = True
node_weight = torch.zeros_like(p)
node_weight.copy_(p)
# Need to handle the concat for layers with different R & S
shape = node_weight.shape
# 1d-tensor
if len(shape) == 1:
node_weight = node_weight.view(1, shape[0])
# 2d-tensor (out, in), e.g. a Linear weight
elif len(shape) == 2:
node_weight = node_weight.view(shape[0], shape[1])
# 3d-tensor (batch, in, out)
elif len(shape) == 3:
node_weight = node_weight.view(shape[0]*shape[1], shape[2])
# 4d-tensor (out, in, h, w), e.g. a Conv2d weight
elif len(shape) == 4:
# convs
node_weight = node_weight.permute(2,3,0,1).contiguous().view(shape[2]*shape[3]*shape[0], shape[1])
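# e.g. (added comment, hypothetical shapes) a Conv2d weight (K=64, C=128, R=3, S=3) becomes a
# (R*S*K, C) = (576, 128) matrix, so every row shares the same C input channels and the
# permutation can be searched over the columns that all siblings in the group have in common.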
if is_node_in_sparse_parameters == False:
print("[search_for_good_permutation] cannot find the node: \'{:}\' in cls.__sparse_parameters, no need to merge its weight for permutation.".format(node_name))
else:
if matrix_group is None:
matrix_group = node_weight
else:
try:
if matrix_group.dim() == node_weight.dim():
matrix_group = torch.cat((matrix_group, node_weight), dim=0) # concat the weights in K dimension, and keep the same C dimension
else: # e.g. when try to merge the Conv and FC layers
print("[search_for_good_permutation] matrix_group dim: {:} is not matched with node_weight dim: {:}.".format(matrix_group.dim(), node_weight.dim()))
print("[search_for_good_permutation] matrix_group shape: \'{:}\' is not matched with node_weight shape: \'{:}\'.".format(matrix_group.size(), node_weight.size()))
if matrix_group.dim() < node_weight.dim():
while node_weight.dim() - matrix_group.dim() > 0:
matrix_group = matrix_group.unsqueeze(matrix_group.dim())
else:
while matrix_group.dim() - node_weight.dim() > 0:
node_weight = node_weight.unsqueeze(node_weight.dim())
print("[search_for_good_permutation] matrix_group shape: \'{:}\' is now matched with node_weight shape: \'{:}\'.".format(matrix_group.size(), node_weight.size()))
matrix_group = torch.cat((matrix_group, node_weight), dim=0) # concat the weights in K dimension, and keep the same C dimension
except:
print("[search_for_good_permutation][warning] cannot merge the weight for node: \'{:}\', with its weight shape: \'{:}\', the matrix_group shape: \'{:}\'.".format(node_name, node_weight.size(), matrix_group.size()))
continue
print("[search_for_good_permutation] have merged the weight for node: \'{:}\', with its weight shape: \'{:}\', the matrix_group shape: \'{:}\'.".format(node_name, node_weight.size(), matrix_group.size()))
if matrix_group is None: # none of the nodes in this unique_siblings_group was found in cls.__sparse_parameters
input_channel_num = 0
print("\n[search_for_good_permutation] init the all-zero list with length \'{:}\' for permutation search sequence of this unique_siblings_group.".format(input_channel_num))
print("[search_for_good_permutation] no need to search the permutation_sequence for empty matrix_group.")
permutation_sequence = [0 for n in range(input_channel_num)]
unique_siblings_groups_permutation_sequence.append(permutation_sequence)
continue
else:
input_channel_num = matrix_group.size()[1]
print("\n[search_for_good_permutation] init the all-zero list with length \'{:}\' for permutation search sequence of this unique_siblings_group.".format(input_channel_num))
permutation_sequence = [0 for n in range(input_channel_num)]
# automatic check for skipping the permutation search process
original_magnitude = (torch.abs(matrix_group)).sum(dtype=torch.float64)
pruned_magnitude = sum_after_2_to_4(matrix_group.cpu().detach().numpy())
diff_ratio = abs(original_magnitude - pruned_magnitude)/original_magnitude
epsilon = 1e-3
print("\n[search_for_good_permutation] Original element abs sum: {:}, Pruned element abs sum: {:}, Diff ratio: {:}".format(original_magnitude, pruned_magnitude, diff_ratio))
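# e.g. (added comment, hypothetical numbers) original_magnitude=1000.0 and pruned_magnitude=999.5
# give diff_ratio=5e-4 < epsilon=1e-3, i.e. the group is already nearly 2:4-sparse and the
# permutation search is skipped for it.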
if diff_ratio < epsilon:
print("[search_for_good_permutation] Original element abs sum is almost same as the pruned element abs sum, further permutation search will not help, skipping!")
print("[search_for_good_permutation] Change the all-zero permutation search sequence to a sequential permutation search sequence.")
permutation_sequence = [n for n in range(input_channel_num)]
unique_siblings_groups_permutation_sequence.append(permutation_sequence)
continue
else:
print("[search_for_good_permutation] Original element abs sum is different from the pruned element abs sum, further permutation search will help, continue with the permutation search!")
# call the permutation search CUDA kernels as ASP extension.
# users can provide a preferred search strategy via a valid 'search_options' dictionary,
# or users can implement their customized 'accelerated_search_for_good_permutation' function.
search_options = {}
# No.1 Strategy: Exhaustive Search
# search_options['strategy'] = 'exhaustive'
# search_options['stripe_group_size'] = 8
# search_options['escape_attempts'] = 100
# No.2 Strategy: Progressive Channel Swap Search
# search_options['strategy'] = 'progressive channel swap'
# search_options['progressive_search_time_limit'] = 10
# search_options['improvement_threshold'] = 1e-9
# No.3 Strategy: User Defined Search
# search_options['strategy'] = 'user defined'
# permutation search time is too long for matrix_group with large channel num
# change from Exhaustive Search to Progressive Channel Swap Search based on input matrix_group size
if input_channel_num > 2048:
search_options['strategy'] = 'progressive channel swap'
search_options['progressive_search_time_limit'] = 120
search_options['improvement_threshold'] = 1e-9
print("[search_for_good_permutation] Change to Progressive Channel Swap Search with a {} second limit, because the channel num {} is too large and would lead to an excessively long permutation search time with Exhaustive Search.".format(search_options['progressive_search_time_limit'], input_channel_num))
start_time_accelerated_search_for_good_permutation = time.perf_counter()
permutation_sequence = accelerated_search_for_good_permutation(matrix_group, options=search_options)
duration_accelerated_search_for_good_permutation = time.perf_counter() - start_time_accelerated_search_for_good_permutation
print("[search_for_good_permutation] Take {:.4f} seconds to finish accelerated_search_for_good_permutation function.".format(duration_accelerated_search_for_good_permutation))
unique_siblings_groups_permutation_sequence.append(permutation_sequence)
fx_graph['unique_siblings']['permutation_sequence'] = unique_siblings_groups_permutation_sequence
if cls.__save_permutation_graph:
cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_search_for_good_permutation.json')) # save the intermediate graph as JSON file for debugging
return fx_graph
@classmethod
def init_permutation_flag(cls, fx_graph):
"""This function is used to init the permutation flag for each node according to the whole network graph built with Torch.FX."""
print("\n[init_permutation_flag] Init the permutation flag for each node according to the whole network graph built with Torch.FX")
sparse_module_names = []
processed_sparse_module_names = [] # Inception-V3, module_name: Conv2d_2a_3x3.conv, node_name: conv2d.1a.3x3.conv
for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
sparse_module_names.append(module_name)
processed_module_name = ''.join(c for c in module_name if c not in string.punctuation).lower()
processed_sparse_module_names.append(processed_module_name)
for node_name in fx_graph.keys():
processed_node_name = ''.join(c for c in node_name if c not in string.punctuation).lower()
distributed_node_name = 'module.' + node_name
processed_distributed_node_name = 'module.' + processed_node_name
node_module_type = fx_graph.get(node_name).get('module_type')
if node_module_type in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']:
node_parents = fx_graph.get(node_name).get('parents')
node_children = fx_graph.get(node_name).get('children')
node_real_parents = fx_graph.get(node_name).get('real_parents')
node_real_children = fx_graph.get(node_name).get('real_children')
node_groups_param = fx_graph.get(node_name).get('groups_param')
is_node_real_children_in_sparse_parameters = False
is_node_real_children_has_group_conv = False
for real_child_item in node_real_children:
processed_real_child_item = ''.join(c for c in real_child_item if c not in string.punctuation).lower()
distributed_real_child_item = 'module.' + real_child_item
processed_distributed_real_child_item = 'module.' + processed_real_child_item
if (real_child_item in sparse_module_names) or (processed_real_child_item in processed_sparse_module_names) or (distributed_real_child_item in sparse_module_names) or (processed_distributed_real_child_item in processed_sparse_module_names):
is_node_real_children_in_sparse_parameters = True
if (fx_graph.get(real_child_item).get('groups_param') not in ['None', '1']):
is_node_real_children_has_group_conv = True
is_node_real_parents_has_group_conv = False
for real_parent_item in node_real_parents:
# notice: we assume that if one item of real_parents needs to permute in C or K dim, then the corresponding flag should be set
# so the items of real_parents may not all share the same 'permutation_type' (e.g., one item is a Group Conv, etc.)
# that's why we also need to check 'is_node_real_parents_has_group_conv'
if (fx_graph.get(real_parent_item).get('groups_param') not in ['None', '1']):
is_node_real_parents_has_group_conv = True
# If the node itself is in sparse_module_names or one of its real_children in sparse_module_names, then it may need the offline permutation
if ((node_name in sparse_module_names) or (processed_node_name in processed_sparse_module_names) or (distributed_node_name in sparse_module_names) or (processed_distributed_node_name in processed_sparse_module_names)) or (is_node_real_children_in_sparse_parameters == True):
if node_groups_param not in ['None', '1']:
# for Group Conv, disable the permutation in 'C' and 'K' dim
fx_graph[node_name]['permutation_type'] = 'None'
elif ('x' in node_parents) or ((node_name not in sparse_module_names) and (processed_node_name not in processed_sparse_module_names) and (distributed_node_name not in sparse_module_names) and (processed_distributed_node_name not in processed_sparse_module_names)):
# for the first Conv/FC (because it is connected to the 'x' node, or is itself not in sparse_module_names) or a Conv/FC that is not compatible with NVIDIA's TC, only permute the K direction
if is_node_real_children_has_group_conv == False:
fx_graph[node_name]['permutation_type'] = 'K'
fx_graph[node_name]['k_permuted'] = 'False'
else: # if node real_children contains Group Conv, disable the permutation for node in 'K' dim
fx_graph[node_name]['permutation_type'] = 'None'
elif ('output' in node_children) or (is_node_real_children_in_sparse_parameters == False):
# for the last FC/Conv (because it is connected to the 'output' node or to a node which is not in sparse_module_names), only permute the C direction
if is_node_real_parents_has_group_conv == False:
fx_graph[node_name]['permutation_type'] = 'C'
fx_graph[node_name]['c_permuted'] = 'False'
else: # if node real_parents contains Group Conv, disable the permutation for node in 'C' dim
fx_graph[node_name]['permutation_type'] = 'None'
else:
if (is_node_real_parents_has_group_conv == False) and (is_node_real_children_has_group_conv == False):
fx_graph[node_name]['permutation_type'] = 'KC'
fx_graph[node_name]['k_permuted'] = 'False'
fx_graph[node_name]['c_permuted'] = 'False'
elif is_node_real_parents_has_group_conv == True: # TODO: if node real_parents contains Group Conv, disable the permutation for node in 'C' dim
fx_graph[node_name]['permutation_type'] = 'K'
fx_graph[node_name]['k_permuted'] = 'False'
else: # if node real_children contains Group Conv, disable the permutation for node in 'K' dim
fx_graph[node_name]['permutation_type'] = 'C'
fx_graph[node_name]['c_permuted'] = 'False'
else:
fx_graph[node_name]['permutation_type'] = 'None'
elif node_module_type in ['torch.nn.modules.batchnorm.BatchNorm2d']:
node_real_parents = fx_graph.get(node_name).get('real_parents')
is_node_real_parents_need_K_permutation = False
is_node_real_parents_has_group_conv = False
for real_parent_item in node_real_parents:
# notice: we assume that if one item of real_parents needs to permute in K dim, then the corresponding flag should be set
# as in most cases a BN only follows one Conv, so this should be OK for now
if fx_graph.get(real_parent_item).get('permutation_type') in ['K', 'KC']:
is_node_real_parents_need_K_permutation = True
if (fx_graph.get(real_parent_item).get('groups_param') not in ['None', '1']):
is_node_real_parents_has_group_conv = True
node_real_children = fx_graph.get(node_name).get('real_children')
is_node_real_children_in_sparse_parameters = False
for real_child_item in node_real_children:
processed_real_child_item = ''.join(c for c in real_child_item if c not in string.punctuation).lower()
distributed_real_child_item = 'module.' + real_child_item
processed_distributed_real_child_item = 'module.' + processed_real_child_item
if (real_child_item in sparse_module_names) or (processed_real_child_item in processed_sparse_module_names) or (distributed_real_child_item in sparse_module_names) or (processed_distributed_real_child_item in processed_sparse_module_names):
is_node_real_children_in_sparse_parameters = True
# Firstly, we should make sure the BN is not the last one (i.e., not connected to a FC/Conv node which is not in sparse_module_names); then:
# If the real_parents of BN node are in sparse_module_names, then it may need the offline permutation
# Or if the real_parents of BN node just needs to permute in K dim
if (is_node_real_children_in_sparse_parameters == True) and (is_node_real_parents_need_K_permutation == True):
if (is_node_real_parents_has_group_conv == False) and (is_node_real_parents_need_K_permutation == True):
fx_graph[node_name]['permutation_type'] = 'K'
fx_graph[node_name]['k_permuted'] = 'False'
else: # if node real_parents contains Group Conv or does not need permutation in 'K' dim, disable the permutation for node in 'K' dim
fx_graph[node_name]['permutation_type'] = 'None'
else:
fx_graph[node_name]['permutation_type'] = 'None'
else:
fx_graph[node_name]['permutation_type'] = 'None'
# A special case: if the node itself is not in sparse_module_names but one of its real_siblings is, the node will not take part in the permutation search, but it may still need to apply the offline permutation in C dim according to the permutation sequence searched for its real_siblings in sparse_module_names
# We handle this as post-processing, because adding it to the previous logic would make it too complex
# Post-processing Step No.1:
print("\n[init_permutation_flag] Post-processing Step No.1.")
node_change_permutation_due_to_siblings = []
for node_name in fx_graph.keys():
node_real_siblings = fx_graph.get(node_name).get('real_siblings')
if node_real_siblings is not None:
is_node_real_siblings_needs_C_permutation = False
for real_sibling_item in node_real_siblings:
if fx_graph.get(real_sibling_item).get('permutation_type') in ['C', 'KC']:
is_node_real_siblings_needs_C_permutation = True
if is_node_real_siblings_needs_C_permutation == True:
print("[init_permutation_flag] node_name: \'{:}\', one of its real siblings need do offline permutation in C dim.".format(node_name))
node_original_permutation_type = fx_graph.get(node_name).get('permutation_type')
if node_original_permutation_type in ['C', 'KC']:
print("[init_permutation_flag] node_name: \'{:}\', its original permutation: \'{:}\' already includes C dim, no need to do No.1 post-processing change.".format(node_name, node_original_permutation_type))
elif node_original_permutation_type == 'None':
fx_graph[node_name]['permutation_type'] = 'C'
print("[init_permutation_flag] node_name: \'{:}\', change its original permutation: \'{:}\' to new permutation: 'C'.".format(node_name, node_original_permutation_type))
node_change_permutation_due_to_siblings.append(node_name)
elif node_original_permutation_type == 'K':
fx_graph[node_name]['permutation_type'] = 'KC'
print("[init_permutation_flag] node_name: \'{:}\', change its original permutation: \'{:}\' to new permutation: 'KC'.".format(node_name, node_original_permutation_type))
node_change_permutation_due_to_siblings.append(node_name)
# Post-processing Step No.2:
print("\n[init_permutation_flag] Post-processing Step No.2.")
for node_name in fx_graph.keys():
node_real_children = fx_graph.get(node_name).get('real_children')
node_module_type = fx_graph.get(node_name).get('module_type')
if (node_real_children is not None) and (node_module_type in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear', 'torch.nn.modules.batchnorm.BatchNorm2d']):
is_node_real_children_has_node_change_permutation = False
for real_child_item in node_real_children:
if real_child_item in node_change_permutation_due_to_siblings:
is_node_real_children_has_node_change_permutation = True
if is_node_real_children_has_node_change_permutation == True:
print("[init_permutation_flag] node_name: \'{:}\', one of its real children has changed permutation due to its siblings.".format(node_name))
node_original_permutation_type = fx_graph.get(node_name).get('permutation_type')
if node_original_permutation_type in ['K', 'KC']:
print("[init_permutation_flag] node_name: \'{:}\', its original permutation: \'{:}\' already includes K dim, no need to do No.2 post-processing change.".format(node_name, node_original_permutation_type))
elif node_original_permutation_type == 'None':
fx_graph[node_name]['permutation_type'] = 'K'
print("[init_permutation_flag] node_name: \'{:}\', change its original permutation: \'{:}\' to new permutation: 'K'.".format(node_name, node_original_permutation_type))
elif node_original_permutation_type == 'C':
fx_graph[node_name]['permutation_type'] = 'KC'
print("[init_permutation_flag] node_name: \'{:}\', change its original permutation: \'{:}\' to new permutation: 'KC'.".format(node_name, node_original_permutation_type))
if cls.__save_permutation_graph:
cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_init_permutation_flag.json')) # save the intermediate graph as JSON file for debugging
return fx_graph
@classmethod
def extract_all_unique_siblings(cls, fx_graph):
"""This function is used to extrat all unique siblings for the whole network graph built with Torch.FX."""
print("\n[extract_all_unique_siblings] Extract all unique siblings for the whole network graph built with Torch.FX")
all_unique_siblings_name = []
all_unique_siblings_module_type = []
for node_name in fx_graph.keys():
fx_graph[node_name]['node_type'] = 'network_node' # use 'node_type' to distinguish real network nodes from auxiliary info nodes such as the 'unique_siblings' node
node_module_type = fx_graph.get(node_name).get('module_type')
node_real_siblings = fx_graph.get(node_name).get('real_siblings')
node_real_siblings_module_type = fx_graph.get(node_name).get('real_siblings_module_type')
if node_real_siblings == []:
print("[extract_all_unique_siblings] node_name: \'{:}\', node module type: \'{:}\', has no real siblings.".format(node_name, node_module_type))
# for Conv/FC layers without real_siblings, we should insert the node itself as a unique siblings group
if node_module_type in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']:
# direct insert will change the real_siblings info for the node in the fx_graph
node_real_siblings_with_node_itself = node_real_siblings.copy()
node_real_siblings_with_node_itself.insert(0, node_name)
node_real_siblings_module_type_with_node_itself = node_real_siblings_module_type.copy()
node_real_siblings_module_type_with_node_itself.insert(0, node_module_type)
all_unique_siblings_name.append(node_real_siblings_with_node_itself)
all_unique_siblings_module_type.append(node_real_siblings_module_type_with_node_itself)
else:
print("[extract_all_unique_siblings] node_name: \'{:}\', node module type: \'{:}\', has {:} real siblings: \'{:}\'.".format(node_name, node_module_type, len(node_real_siblings), node_real_siblings))
# Two duplicated sibling lists must contain exactly the same node names.
# So if the node name is already included in one of the lists in all_unique_siblings_name, this node's real_siblings duplicate that list and can be skipped.
# Otherwise, we should insert the node together with its real_siblings as a new unique siblings list.
has_include_siblings = False
for unique_siblings_item in all_unique_siblings_name:
if node_name in unique_siblings_item:
has_include_siblings = True
if has_include_siblings == False:
# direct insert will change the real_siblings info for the node in the fx_graph
node_real_siblings_with_node_itself = node_real_siblings.copy()
node_real_siblings_with_node_itself.insert(0, node_name)
node_real_siblings_module_type_with_node_itself = node_real_siblings_module_type.copy()
node_real_siblings_module_type_with_node_itself.insert(0, node_module_type)
all_unique_siblings_name.append(node_real_siblings_with_node_itself)
all_unique_siblings_module_type.append(node_real_siblings_module_type_with_node_itself)
fx_graph['unique_siblings'] = {}
fx_graph['unique_siblings']['name'] = all_unique_siblings_name
fx_graph['unique_siblings']['module_type'] = all_unique_siblings_module_type
fx_graph['unique_siblings']['node_type'] = 'auxiliary_info_node'
if cls.__save_permutation_graph:
cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_extract_all_unique_siblings.json')) # save the intermediate graph as JSON file for debugging
return fx_graph
@classmethod
def find_real_siblings(cls, fx_graph):
"""This function is used to find all siblings for each node according to the whole network graph built with Torch.FX.
we need to find siblings recursively, because siblings may have siblings via other parents we don't know about.
"""
print("\n[find_real_siblings] Find all siblings for each node according to the whole network graph built with Torch.FX")
for node_name in fx_graph.keys():
node_real_siblings_name = []
node_real_siblings_module_type = []
node_real_parents = fx_graph.get(node_name).get('real_parents')
node_module_type = fx_graph.get(node_name).get('module_type')
if node_module_type not in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']:
print("[find_real_siblings] node_name: \'{:}\', node module type: \'{:}\', has no real siblings.".format(node_name, node_module_type))
else:
print("[find_real_siblings] node_name: \'{:}\', node module type: \'{:}\', may have real siblings.".format(node_name, node_module_type))
# sibling means the nodes share the same real parent
for real_parent_item in node_real_parents:
for real_child_item in fx_graph.get(real_parent_item).get('real_children'):
if real_child_item != node_name:
sibling_module_type = fx_graph.get(real_child_item).get('module_type')
print("[find_real_siblings] node_name: \'{:}\', has one real sibling: \'{:}\', its real sibling module type: \'{:}\'.".format(node_name, real_child_item, sibling_module_type))
node_real_siblings_name.append(real_child_item)
node_real_siblings_module_type.append(sibling_module_type)
# remove the duplicated real siblings
exclusive_node_real_siblings_name = []
exclusive_node_real_siblings_module_type = []
item_index = 0
duplicated_real_siblings = 0
for item in node_real_siblings_name:
if item not in exclusive_node_real_siblings_name:
exclusive_node_real_siblings_name.append(item)
exclusive_node_real_siblings_module_type.append(node_real_siblings_module_type[item_index])
else:
duplicated_real_siblings = duplicated_real_siblings + 1
item_index = item_index + 1
if duplicated_real_siblings > 0:
print("[find_real_siblings] node_name: \'{:}\', remove {:} duplicated real siblings.".format(node_name, duplicated_real_siblings))
fx_graph[node_name]['real_siblings'] = exclusive_node_real_siblings_name
fx_graph[node_name]['real_siblings_module_type'] = exclusive_node_real_siblings_module_type
if cls.__save_permutation_graph:
cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_find_real_siblings.json')) # save the intermediate graph as JSON file for debugging
return fx_graph
@classmethod
def recursive_find_real_children(cls, node_name, fx_graph):
"""This function is used to recursively find the real children for each node according to the whole network graph built with Torch.FX.
Used as the sub-function of find_real_children.
"""
node_real_children_name = []
node_real_children_module_type = []
if node_name in fx_graph: # this check could be removed, because node_name always comes from the 'children' list of a node already in fx_graph
node_children = fx_graph.get(node_name).get('children')
node_module_type = fx_graph.get(node_name).get('module_type')
has_visit_children_num = 0
has_real_children_num = 0
sub_node_need_recursive_search = []
while has_visit_children_num < len(node_children):
for child_name in node_children:
if child_name != 'output': # 'output' node has no 'module_type'
child_module_type = fx_graph.get(child_name).get('module_type')
if child_module_type in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']:
print("[recursive_find_real_children] node_name: \'{:}\', has one real child: \'{:}\', its real child module type: \'{:}\'.".format(node_name, child_name, child_module_type))
node_real_children_name.append(child_name)
node_real_children_module_type.append(child_module_type)
has_real_children_num = has_real_children_num + 1
else:
print("[recursive_find_real_children] node_name: \'{:}\', its child: \'{:}\' with module type: \'{:}\', needs recursive search.".format(node_name, child_name, child_module_type))
sub_node_need_recursive_search.append(child_name)
else:
print("[recursive_find_real_children] node_name: \'{:}\', its child: \'{:}\' with no module type, is not its real child.".format(node_name, child_name))
has_visit_children_num = has_visit_children_num + 1
if len(sub_node_need_recursive_search) > 0:
for sub_node in sub_node_need_recursive_search:
if fx_graph.get(sub_node).get('real_children') == []:
sub_node_real_children_name, sub_node_real_children_module_type = cls.recursive_find_real_children(sub_node, fx_graph)
else:
# if the sub_node already find the 'real_children', no need to do recursive search
sub_node_real_children_name = fx_graph.get(sub_node).get('real_children')
sub_node_real_children_module_type = fx_graph.get(sub_node).get('real_children_module_type')
node_real_children_name.extend(sub_node_real_children_name)
node_real_children_module_type.extend(sub_node_real_children_module_type)
return node_real_children_name, node_real_children_module_type
@classmethod
def find_real_children(cls, fx_graph):
"""This function is used to find the real children for each node according to the whole network graph built with Torch.FX.
For example:
The real children of a Conv are the subsequent Conv/FC layers.
The real children of a BN or other layers that need no permutation are also the subsequent Conv/FC layers.
"""
print("\n[find_real_children] Find the real children for each node according to the whole network graph built with Torch.FX")
from sys import version_info
if version_info.major == 3 and version_info.minor >= 8:
reversible_fx_graph_keys = fx_graph.keys()
else: # the 'dict_keys' object is not reversible before Python 3.8
reversible_fx_graph_keys = list(fx_graph.keys())
for node_name in reversed(reversible_fx_graph_keys): # as an optimization, find the real children from back to front so that already saved 'real_children' can be reused
node_real_children_name = []
node_real_children_module_type = []
node_children = fx_graph.get(node_name).get('children')
node_module_type = fx_graph.get(node_name).get('module_type')
if node_module_type not in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']:
print("\n[find_real_children] node_name: \'{:}\', node module type: \'{:}\', children num: {:}, recursive to find real children.".format(node_name, node_module_type, len(node_children)))
node_real_children_name, node_real_children_module_type = cls.recursive_find_real_children(node_name, fx_graph)
else: # quick method, but it cannot get the real children for layers that need no permutation, like BN
print("\n[find_real_children] node_name: \'{:}\', node module type: \'{:}\', children num: {:}, can directly find real children.".format(node_name, node_module_type, len(node_children)))
# if the node is in the 'real_parents' list of the other node, then the other node is the real children for this node
for other_node_name in fx_graph.keys():
if (other_node_name != node_name) and (node_name in fx_graph.get(other_node_name).get('real_parents')):
child_module_type = fx_graph.get(other_node_name).get('module_type')
if child_module_type in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']:
print("[find_real_children] node_name: \'{:}\', has one real child: \'{:}\', its real child module type: \'{:}\'.".format(node_name, other_node_name, child_module_type))
node_real_children_name.append(other_node_name)
node_real_children_module_type.append(child_module_type)
# remove the duplicated real children
exclusive_node_real_children_name = []
exclusive_node_real_children_module_type = []
item_index = 0
duplicated_real_children = 0
for item in node_real_children_name:
if item not in exclusive_node_real_children_name:
exclusive_node_real_children_name.append(item)
exclusive_node_real_children_module_type.append(node_real_children_module_type[item_index])
else:
duplicated_real_children = duplicated_real_children + 1
item_index = item_index + 1
if duplicated_real_children > 0:
print("[find_real_children] node_name: \'{:}\', remove {:} duplicated real children.".format(node_name, duplicated_real_children))
fx_graph[node_name]['real_children'] = exclusive_node_real_children_name
fx_graph[node_name]['real_children_module_type'] = exclusive_node_real_children_module_type
if cls.__save_permutation_graph:
cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_find_real_children.json')) # save the intermediate graph as JSON file for debugging
return fx_graph
@classmethod
def find_real_parents(cls, fx_graph):
"""This function is used to find the real parents for each node according to the whole network graph built with Torch.FX.
For example:
The real parent of BN is the previous Conv/FC.
The real parent of Conv is the previous Conv/FC.
"""
print("\n[find_real_parents] Find the real parents for each node according to the whole network graph built with Torch.FX")
for node_name in fx_graph.keys():
node_real_parents_name = []
node_real_parents_module_type = []
node_parents = fx_graph.get(node_name).get('parents')
print("[find_real_parents] node_name: \'{:}\', parents num: {:}".format(node_name, len(node_parents)))
has_visit_parent_num = 0
while has_visit_parent_num < len(node_parents):
for parent_name in node_parents:
if parent_name in fx_graph:
parent_module_type = fx_graph.get(parent_name).get('module_type')
if parent_module_type in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']:
print("[find_real_parents] node_name: \'{:}\', has one real parent: \'{:}\', its real parent module type: \'{:}\'.".format(node_name, parent_name, parent_module_type))
node_real_parents_name.append(parent_name)
node_real_parents_module_type.append(parent_module_type)
else:
print("[find_real_parents] node_name: \'{:}\', has one/several real parent(s): \'{:}\', its real parent module type: \'{:}\'.".format(node_name, fx_graph[parent_name]['real_parents'], fx_graph[parent_name]['real_parents_module_type']))
for real_parent_item in fx_graph[parent_name]['real_parents']:
node_real_parents_name.append(real_parent_item)
for real_parent_module_type_item in fx_graph[parent_name]['real_parents_module_type']:
node_real_parents_module_type.append(real_parent_module_type_item)
else:
print("[find_real_parents] node_name: \'{:}\', has no real parent because this is the first node.".format(node_name))
has_visit_parent_num = has_visit_parent_num + 1
# remove the duplicated real parents
exclusive_node_real_parents_name = []
exclusive_node_real_parents_module_type = []
exclusive_node_real_parents_groups_param = []
item_index = 0
duplicated_real_parents = 0
for item in node_real_parents_name:
if item not in exclusive_node_real_parents_name:
exclusive_node_real_parents_name.append(item)
exclusive_node_real_parents_module_type.append(node_real_parents_module_type[item_index])
exclusive_node_real_parents_groups_param.append(fx_graph.get(item).get('groups_param'))
else:
duplicated_real_parents = duplicated_real_parents + 1
item_index = item_index + 1
if duplicated_real_parents > 0:
print("[find_real_parents] node_name: \'{:}\', remove {:} duplicated real parents.".format(node_name, duplicated_real_parents))
fx_graph[node_name]['real_parents'] = exclusive_node_real_parents_name
fx_graph[node_name]['real_parents_module_type'] = exclusive_node_real_parents_module_type
fx_graph[node_name]['real_parents_groups_param'] = exclusive_node_real_parents_groups_param
if cls.__save_permutation_graph:
cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_find_real_parent.json')) # save the intermediate graph as JSON file for debugging
return fx_graph
@classmethod
def build_fx_graph(cls, model, dump_fx_graph=False, save_dumped_fx_graph='./model_fx_graph.json'):
"""This function is used to build the whole network graph with Torch.FX features."""
success = True
torch_version = str(torch.__version__)
torch_version_major = int(torch_version.split('.')[0])
torch_version_minor = int(torch_version.split('.')[1])
try:
torch_version_minimum = int(torch_version.split('.')[2])
except ValueError: # support non-standard version strings
torch_version_minimum = torch_version.split('.')[2]
print("[build_fx_graph] The torch version is: {}, version major is: {}, version minor is: {}, version minimum is: {}".format(torch_version, torch_version_major, torch_version_minor, torch_version_minimum))
if (torch_version_major, torch_version_minor) >= (1, 8): # tuple comparison also covers torch 2.x and later
print("[build_fx_graph] The Torch.FX is supported.")
else: # Torch.FX is introduced in torch 1.8.0
print("[build_fx_graph] The Torch.FX is not supported. So cannot build the Torch.FX graph.")
success = False
network_fx_graph = {}
return network_fx_graph, success
print("\n[build_fx_graph] Print the model structure with pure PyTorch function")
print(model)
print("\n[build_fx_graph] Build the module name and type dictionary")
module_name_type_dict = {}
module_name_group_conv_dict = {}
for name, mod in model.named_modules():
print("[build_fx_graph] module_name: {}, module type: {}".format(name, type(mod)))
module_name_type_dict[name] = str(type(mod)).split("\'")[1]
try:
print("[build_fx_graph] this module has \'group\' param with value: {}".format(mod.groups))
module_name_group_conv_dict[name] = str(mod.groups)
except AttributeError: # the module has no 'groups' attribute, i.e. it is not a grouped convolution
module_name_group_conv_dict[name] = 'None'
continue
graph_module = cls.print_raw_fx_graph(model, print_tabular=True)
# keep track of children and parents for each layer (could be call_module or call_function)
print("\n[build_fx_graph] Print the children and parents relationship for each layer")
network_fx_graph = {}
for node in graph_module.graph.nodes:
if node.op == 'placeholder':
print("[build_fx_graph] This is the \'input\' node: {:}".format(node.target))
continue
elif node.op == 'get_attr':
print("[build_fx_graph] This is the \'get_attr\' node: {:}".format(node.target))
continue
elif node.op == 'call_function': # e.g. 'adaptive.avg.pool2d', 'add', 'cat', 'flatten', 'floordiv', 'getattr', 'getitem', 'hardsigmoid', 'mean', 'mul', 'relu', 'transpose'
node_parent, node_children = get_node_parent_children(node)
converted_node_name=convert_fx_node_name(node.name)
print("[build_fx_graph] This is the \'call_function\' node: {:}, its parent list: {:}, its children list: {:}".format(converted_node_name, node_parent, node_children))
network_fx_graph[converted_node_name] = {}
network_fx_graph[converted_node_name]['parents'] = node_parent
network_fx_graph[converted_node_name]['children'] = node_children
network_fx_graph[converted_node_name]['fx_op'] = 'call_function'
elif node.op == 'call_method': # e.g. 'chunk', 'contiguous', 'mean', 'size', 'unsqueeze', 'view'
node_parent, node_children = get_node_parent_children(node)
converted_node_name=convert_fx_node_name(node.name)
print("[build_fx_graph] This is the \'call_method\' node: {:}, its parent list: {:}, its children list: {:}".format(converted_node_name, node_parent, node_children))
network_fx_graph[converted_node_name] = {}
network_fx_graph[converted_node_name]['parents'] = node_parent
network_fx_graph[converted_node_name]['children'] = node_children
network_fx_graph[converted_node_name]['fx_op'] = 'call_method'
continue
elif node.op == 'call_module':
node_parent, node_children = get_node_parent_children(node)
converted_node_name=convert_fx_node_name(node.name)
# check whether the converted_node_name is same as node.target, especially for ReLU case
if converted_node_name != node.target:
print("[build_fx_graph][warning] The target name from Torch.FX is \'{:}\', the manually converted node name is \'{:}\', not the same one, choose the converted node name".format(node.target, converted_node_name))
# assume that modules sharing the same target name have the same type, because converted_node_name may not be obtainable from model.named_modules(), e.g. a ReLU defined in the forward function
node_type = module_name_type_dict[node.target]
print("[build_fx_graph] This is the \'call_module\' node: {:}, its parent list: {:}, its children list: {:}, its type: {:}".format(converted_node_name, node_parent, node_children, node_type))
network_fx_graph[converted_node_name] = {}
network_fx_graph[converted_node_name]['parents'] = node_parent
network_fx_graph[converted_node_name]['children'] = node_children
network_fx_graph[converted_node_name]['fx_op'] = 'call_module'
network_fx_graph[converted_node_name]['module_type'] = node_type
network_fx_graph[converted_node_name]['groups_param'] = module_name_group_conv_dict[node.target]
elif node.op == 'output':
print("[build_fx_graph] This is the \'output\' node: {:}".format(node.target))
continue
if dump_fx_graph:
print("\n[build_fx_graph] Dump the overall dict for children and parents relationship into JSON file")
cls.save_graph_to_json(network_fx_graph, save_dumped_graph_path_with_name=save_dumped_fx_graph)
return network_fx_graph, success
@classmethod
def print_raw_fx_graph(cls, model, print_tabular=False, generate_python_code=False):
"""This function is used to print the intermediate representation (IR) - Graph representation with Torch.FX features."""
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
try:
symbolic_traced : torch.fx.GraphModule = symbolic_trace(model)
except:
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
print("\n[print_raw_fx_graph] Meet the fatal fault when trying to symbolic trace the model with Torch.FX")
raise
exit(0)
# High-level intermediate representation (IR) - Graph representation
print("\n[print_raw_fx_graph] Print the intermediate representation (IR) with Torch.FX")
print(symbolic_traced.graph)
if print_tabular:
print("\n[print_raw_fx_graph] Print the intermediate representation (IR) with Torch.FX in a table format")
try:
symbolic_traced.graph.print_tabular()
except AttributeError: # to avoid the AttributeError: 'Graph' object has no attribute 'print_tabular'
print("[print_raw_fx_graph][Warning] \'print_tabular\' function is not supported in current Torch version. Skip!")
# Code generation - valid Python code
if generate_python_code:
print("\n[print_raw_fx_graph] Create valid Python code matching the IR/Graph's semantics with Torch.FX")
print(symbolic_traced.code)
return symbolic_traced
@classmethod
def save_graph_to_json(cls, graph, save_dumped_graph_path_with_name='./model_fx_graph.json'):
"""This function is used to same the graph into JSON file."""
# use dumps to transfer the dict to JSON string
json_graph_str = json.dumps(graph)
with open(save_dumped_graph_path_with_name, 'w', encoding='utf-8') as dumped_graph_file:
dumped_graph_file.write(json_graph_str) # write the transferred JSON string into JSON file
#include <stdio.h>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
namespace py = pybind11;
#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
if (code != cudaSuccess)
{
fprintf(stderr,"GPUassert %d: %s %s %d\n", (int)code, cudaGetErrorString(code), file, line);
if (abort) exit(code);
}
}
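// Added note for clarity: given four values, the best 2:4 choice keeps the two entries with the largest magnitudes,
// so the retained magnitude equals the maximum pairwise sum of absolute values, which is what group_2_to_4 computes.
// Illustrative example (not in the original source): for vals = (0.1f, -3.0f, 0.2f, 2.5f) the six pair sums of magnitudes
// are 3.1, 0.3, 2.6, 3.2, 5.5 and 2.7, so group_2_to_4 returns 5.5 (= |-3.0| + |2.5|).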
__device__ float group_2_to_4(float4 vals)
{
vals.x = fabs(vals.x);
vals.y = fabs(vals.y);
vals.z = fabs(vals.z);
vals.w = fabs(vals.w);
float sum0 = vals.x + vals.y;
float sum1 = vals.x + vals.z;
float sum2 = vals.x + vals.w;
float sum3 = vals.y + vals.z;
float sum4 = vals.y + vals.w;
float sum5 = vals.z + vals.w;
float best_sum0 = fmax(sum0, sum1);
float best_sum1 = fmax(sum2, sum3);
float best_sum2 = fmax(sum4, sum5);
float best_sum = fmax(fmax(best_sum0, best_sum1), best_sum2);
return best_sum;
}
inline float* float_ptr_from_numpy(py::array_t<float>& py_float)
{
return (float*)py_float.data();
}
inline unsigned int* uint_ptr_from_numpy(py::array_t<unsigned int>& py_uint)
{
return (unsigned int*)py_uint.data();
}
__global__ void subset_sum_after_2_to_4(float* matrix,
unsigned int rows,
unsigned int cols,
unsigned int start_col,
unsigned int end_col,
float* output)
{
// vectorize
float4* mat4 = (float4*) matrix;
cols /= 4;
start_col /= 4;
end_col /= 4;
// each thread in a block takes some number of rows
size_t num_rows = max((int)ceilf((float)rows / (float)blockDim.x), 1);
size_t row_offset = num_rows * threadIdx.x;
// each block takes some number of columns
size_t num_cols = (end_col - start_col) / gridDim.x;
size_t col_offset = num_cols * blockIdx.x;
start_col += col_offset;
end_col = start_col + num_cols;
float sum = 0.0f;
for ( unsigned int r = row_offset; r < row_offset + num_rows; ++r ) {
if (r < rows) {
for ( unsigned int c = start_col; c < end_col; c++ ) {
sum += group_2_to_4(mat4[r * cols + c]);
}
}
}
atomicAdd(output, sum);
}
// build the entire permute map at once
// each block handles one group of stripes
// each thread in the block handles the same permutation at the same time on different rows before moving to the next permutation
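// Data layout (added for clarity, as read from the code below): 'output' holds num_groups x num_permutations running sums,
// indexed as output[blockIdx.x * num_permutations + p]; after the reduction, thread 0 of each block overwrites slot 0 of its
// group with (best magnitude - magnitude of permutation 0) and records the winning permutation index in best_indices[blockIdx.x].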
__global__ void build_permute_map(float* matrix,
unsigned int rows,
unsigned int cols,
unsigned int* stripes,
unsigned int group_width,
unsigned int* permutations,
unsigned int num_permutations,
unsigned int perm_length,
float* output,
unsigned int* best_indices)
{
// vectorize
float4* mat4 = (float4*) matrix;
cols /= 4;
// each block handles a group of stripes
unsigned int* stripe_group = (unsigned int*)&stripes[blockIdx.x*group_width];
// shared memory: 32 threads each need 16*2 floats (one local copy of the stripe group and one permuted copy)
extern __shared__ float pm_shared[32][32];
float4* local_stripes = (float4*)&pm_shared[threadIdx.x];
float* local_columns = (float*) &pm_shared[threadIdx.x];
float4* permuted_stripes = (float4*) &local_stripes[4];
float* permuted_columns = (float*) &local_columns[16];
// each thread handles all permutations in the row before moving on to the next row
size_t num_rows = max((int)ceilf((float)rows / (float)blockDim.x), 1);
size_t row_offset = num_rows * threadIdx.x;
for ( unsigned int r = row_offset; r < row_offset + num_rows; ++r) {
if (r >= rows)
break;
// load a row into smem
for ( unsigned int s = 0; s < group_width; ++s) {
unsigned int const stripe = stripe_group[s];
local_stripes[s] = mat4[r*cols+stripe];
}
for ( unsigned int p = 0; p < num_permutations; ++p) {
unsigned int* permutation = &permutations[p*perm_length];
float sum = 0.0f;
// permute
#pragma unroll 4
for ( unsigned int c = 0; c < group_width*4; ++c) {
permuted_columns[c] = local_columns[permutation[c]];
}
// sum 2:4
for ( unsigned int s = 0; s < group_width; ++s) {
sum += group_2_to_4(permuted_stripes[s]);
}
// update the running sum for this stripe group's permutation
atomicAdd(&output[blockIdx.x*num_permutations + p], sum);
}
}
// at this point, each permutation's sum in this stripe group has been calculated
// now, find the best option
__syncthreads();
if (threadIdx.x == 0) {
unsigned int best_permutation = 0;
float best_magnitude = output[blockIdx.x*num_permutations];
float base_magnitude = best_magnitude;
//#pragma unroll 32
for (unsigned int p = 1; p < num_permutations; ++p) {
float magnitude = output[blockIdx.x*num_permutations+p];
if (magnitude > best_magnitude) {
best_permutation = p;
best_magnitude = magnitude;
}
}
output[blockIdx.x*num_permutations] = best_magnitude - base_magnitude;
best_indices[blockIdx.x] = best_permutation;
}
}
void free_sum_after_2_to_4_memory(float** dmatrix,
float** dresult)
{
cudaFree(*dmatrix);
cudaFree(*dresult);
}
int set_up_sum_after_2_to_4_memory(float** dmatrix,
unsigned int rows,
unsigned int cols,
float** dresult)
{
static unsigned int setupRows = 0;
static unsigned int setupCols = 0;
static bool allocated = false;
int fresh_allocation = 0;
if (!allocated ||
setupRows != rows ||
setupCols != cols)
{
if (allocated)
free_sum_after_2_to_4_memory(dmatrix, dresult);
gpuErrchk(cudaMalloc( (void**) dmatrix, rows*cols*sizeof(float)));
gpuErrchk(cudaMalloc( (void**) dresult, sizeof(float)));
setupRows = rows;
setupCols = cols;
fresh_allocation = 1;
}
allocated = true;
return fresh_allocation;
}
int run_subset_sum_after_2_to_4(py::array_t<float>& py_matrix,
unsigned int rows,
unsigned int cols,
unsigned int start_col,
unsigned int end_col,
unsigned int blocks,
unsigned int threads,
py::array_t<float>& py_output)
{
static float* d_matrix;
static float* d_result;
int fresh_allocation = set_up_sum_after_2_to_4_memory(&d_matrix, rows, cols, &d_result);
float* matrix = float_ptr_from_numpy(py_matrix);
float* output = float_ptr_from_numpy(py_output);
gpuErrchk(cudaMemcpy( d_matrix, matrix, rows*cols*sizeof(float), cudaMemcpyHostToDevice ));
gpuErrchk(cudaMemset( d_result, 0, sizeof(float)));
subset_sum_after_2_to_4<<<blocks, threads>>>(d_matrix, rows, cols, start_col, end_col, d_result);
gpuErrchk(cudaDeviceSynchronize());
gpuErrchk(cudaMemcpy( output, d_result, sizeof(float), cudaMemcpyDeviceToHost ));
return 0;
}
void set_up_permute_map_memory(float** dmatrix,
unsigned int rows,
unsigned int cols,
unsigned int** dstripes,
unsigned int num_groups,
unsigned int group_width,
unsigned int** dpermutations,
unsigned int num_permutations,
unsigned int perm_length,
float** doutput,
unsigned int** dindices,
float** hresult,
unsigned int** hindices)
{
static unsigned int setUpRows = 0;
static unsigned int setUpCols = 0;
static unsigned int setUpGroupWidth = 0;
static unsigned int setUpNumGroups = 0;
static unsigned int setUpNumPerms = 0;
static unsigned int setUpPermLength = 0;
if (setUpRows != rows ||
setUpCols != cols) {
if (*dmatrix != NULL) { gpuErrchk(cudaFree(*dmatrix)); *dmatrix = NULL; }
gpuErrchk(cudaMalloc( (void**) dmatrix, rows*cols*sizeof(float)));
}
if (setUpGroupWidth < group_width ||
setUpNumGroups < num_groups) {
if (*dstripes != NULL) { gpuErrchk(cudaFree(*dstripes)); *dstripes = NULL; }
gpuErrchk(cudaMalloc( (void**) dstripes, num_groups*group_width*sizeof(unsigned int)));
if (setUpNumGroups < num_groups) {
if (*dindices != NULL) { gpuErrchk(cudaFree(*dindices)); *dindices = NULL; }
gpuErrchk(cudaMalloc( (void**) dindices, num_groups*sizeof(unsigned int)));
if (*hindices != NULL) { free(*hindices); *hindices = NULL; }
*hindices = (unsigned int*) malloc (num_groups*sizeof(unsigned int));
}
}
if (setUpNumPerms < num_permutations ||
setUpPermLength < perm_length) {
if (*dpermutations != NULL) { gpuErrchk(cudaFree(*dpermutations)); *dpermutations = NULL; }
gpuErrchk(cudaMalloc( (void**) dpermutations, perm_length*num_permutations*sizeof(unsigned int)));
}
if (setUpNumPerms < num_permutations ||
setUpNumGroups < num_groups) {
if (*doutput != NULL) { gpuErrchk(cudaFree(*doutput)); *doutput = NULL; }
gpuErrchk(cudaMalloc( (void**) doutput, num_permutations*num_groups*sizeof(float)));
if (*hresult != NULL) { free(*hresult); *hresult = NULL; }
*hresult = (float*) malloc(num_permutations*num_groups*sizeof(float));
}
setUpRows = rows;
setUpCols = cols;
setUpGroupWidth = group_width;
setUpNumGroups = num_groups;
setUpNumPerms = num_permutations;
setUpPermLength = perm_length;
}
int run_build_permute_map(py::array_t<float>& py_matrix,
unsigned int rows,
unsigned int cols,
py::array_t<unsigned int>& py_stripes,
unsigned int num_groups,
unsigned int group_width,
py::array_t<unsigned int>& py_permutations,
//unsigned int num_permutations,
unsigned int perm_length,
py::array_t<float>& py_improvements,
py::array_t<unsigned int>& py_best_indices)
{
static float* d_matrix = NULL;
static unsigned int* d_stripes = NULL;
static unsigned int* d_permutations = NULL;
static float* d_output = NULL;
static unsigned int* d_indices = NULL;
static float* hresult = NULL;
static unsigned int* hindices = NULL;
//const unsigned int cols = py_matrix.size() / rows;
//const unsigned int num_groups = py_stripes.size() / group_width;
//const unsigned int perm_length = group_width * 4; // 2:4 sparsity - each stripe in the group is 4 elements wide
const unsigned int num_permutations = py_permutations.size() / perm_length;
const unsigned int MAX_GROUPS_PER_LAUNCH = num_permutations <= 5775 ? 1820 : 40;
const unsigned int full_launches = num_groups / MAX_GROUPS_PER_LAUNCH;
const unsigned int final_launch = num_groups % MAX_GROUPS_PER_LAUNCH;
const unsigned int launches = full_launches + (final_launch != 0 ? 1 : 0);
set_up_permute_map_memory(&d_matrix, rows, cols, &d_stripes, min(num_groups,MAX_GROUPS_PER_LAUNCH), group_width, &d_permutations, num_permutations, perm_length, &d_output, &d_indices, &hresult, &hindices);
float* matrix = float_ptr_from_numpy(py_matrix);
unsigned int* stripes = uint_ptr_from_numpy(py_stripes);
unsigned int* permutations = uint_ptr_from_numpy(py_permutations);
float* improvements = float_ptr_from_numpy(py_improvements);
unsigned int* best_indices = uint_ptr_from_numpy(py_best_indices);
gpuErrchk(cudaMemcpy( d_matrix, matrix, rows*cols*sizeof(float), cudaMemcpyHostToDevice ));
gpuErrchk(cudaMemcpy( d_permutations, permutations, num_permutations*perm_length*sizeof(unsigned int), cudaMemcpyHostToDevice ));
unsigned int group_offset = 0;
for (unsigned int l = 0; l < launches; ++l)
{
unsigned int groups_this_launch = (l < full_launches) ? MAX_GROUPS_PER_LAUNCH : final_launch;
gpuErrchk(cudaMemcpy( d_stripes, &stripes[group_offset*group_width], groups_this_launch*group_width*sizeof(unsigned int), cudaMemcpyHostToDevice ));
gpuErrchk(cudaMemset( d_output, 0, groups_this_launch*num_permutations*sizeof(float)));
gpuErrchk(cudaMemset( d_indices, 0, groups_this_launch*sizeof(unsigned int)));
unsigned int shmem = 32*(32)*sizeof(float);
build_permute_map<<<groups_this_launch, 32, shmem>>>(d_matrix, rows, cols, d_stripes, group_width, d_permutations, num_permutations, perm_length, d_output, d_indices);
gpuErrchk(cudaDeviceSynchronize());
gpuErrchk(cudaMemcpy( hresult, d_output, num_permutations*groups_this_launch*sizeof(float), cudaMemcpyDeviceToHost ));
gpuErrchk(cudaMemcpy( hindices, d_indices, groups_this_launch*sizeof(unsigned int), cudaMemcpyDeviceToHost ));
// thread 0 stored each group's best improvement (best magnitude minus the unpermuted magnitude) in the first slot of that group
for (unsigned int g = 0; g < groups_this_launch; ++g) {
improvements[group_offset+g] = hresult[g*num_permutations];
best_indices[group_offset+g] = hindices[g];
}
group_offset += groups_this_launch;
}
return 0;
}
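// Python-side usage sketch (illustrative, mirroring the callers in permutation_utilities.py and exhaustive_search.py):
//   import permutation_search_cuda as kernels
//   kernels.sum_after_2_to_4(matrix_view, rows, cols, 0, cols, blocks, threads, sum_view)
//   kernels.build_permute_map(matrix_view, rows, cols, stripe_groups_view, num_groups, group_width,
//                             permutation_view, perm_length, improvements_view, best_indices_view)
// All array arguments are flat, contiguous numpy arrays (float32 / uint32) as prepared by those callers.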
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("sum_after_2_to_4", &run_subset_sum_after_2_to_4, "matrix sum after applying 2:4 (CUDA)");
m.def("build_permute_map", &run_build_permute_map, "optimize stripe groups (CUDA)");
}
\ No newline at end of file
from .call_permutation_search_kernels import accelerated_search_for_good_permutation
from .permutation_utilities import sum_after_2_to_4
\ No newline at end of file
import numpy as np
from .permutation_utilities import *
from .exhaustive_search import Exhaustive_Search
def accelerated_search_for_good_permutation(matrix_group, options=None):
"""This function is used to call the permutation search CUDA kernels.
Users can specify a preferred search strategy by providing a valid 'options' dictionary,
or implement their own customized 'accelerated_search_for_good_permutation' function.
"""
input_matrix = matrix_group.cpu().detach().numpy()
print("\n[accelerated_search_for_good_permutation] input matrix shape: \'{:}\'.".format(input_matrix.shape))
result = np.copy(input_matrix)
# init a sequential permutation search sequence
input_channel_num = matrix_group.size()[1]
permutation_sequence = [n for n in range(input_channel_num)]
duration = 0.0
if options is None:
options = {}
if 'strategy' not in options: # right now, the default permutation search strategy is: 'exhaustive' search
options['strategy'] = 'exhaustive'
print("[accelerated_search_for_good_permutation] the permutation strategy is: \'{:} search\'.".format(options['strategy']))
# define sub options for each search strategy
if options['strategy'] == 'exhaustive':
# right now, the default options for 'exhaustive' search is: 'exhaustive,8,100'
if 'stripe_group_size' not in options:
options['stripe_group_size'] = 8
if 'escape_attempts' not in options:
options['escape_attempts'] = 100
elif options['strategy'] == 'progressive channel swap':
# just swaps meaningful channels, keeping the good swaps, until the search time limit expires.
if 'progressive_search_time_limit' not in options:
options['progressive_search_time_limit'] = 60
if 'improvement_threshold' not in options:
options['improvement_threshold'] = 1e-9
# execute the requested strategy
if options['strategy'] == 'exhaustive':
result, duration, permutation_sequence = Exhaustive_Search(result, stripe_group_size=options['stripe_group_size'], escape_attempts=options['escape_attempts'])
elif options['strategy'] == 'progressive channel swap':
real_swap_num = 0
start_time = time.perf_counter()
while time.perf_counter() - start_time < options['progressive_search_time_limit']:
src = np.random.randint(result.shape[1])
dst = np.random.randint(result.shape[1])
src_group = int(src/4)
dst_group = int(dst/4)
if src_group == dst_group: # channel swapping within a stripe does nothing
continue
new_sum, improvement = try_swap(result, dst, src)
if improvement > options['improvement_threshold']:
result[...,[src,dst]] = result[...,[dst,src]]
permutation_sequence[src], permutation_sequence[dst] = permutation_sequence[dst], permutation_sequence[src]
real_swap_num += 1
duration = time.perf_counter() - start_time
print("\tFinally swap {} channel pairs until the search time limit expires.".format(real_swap_num))
elif options['strategy'] == 'user defined': # need to get the permutated matrix (result) by applying customized permutation search function
print("[accelerated_search_for_good_permutation] Use the user customized permutation search function!")
else:
print("[accelerated_search_for_good_permutation] Cannot find the implementation of the required strategy!")
print("[accelerated_search_for_good_permutation] Take {:.4f} seconds to search the permutation sequence.".format(duration))
# With the new version of the Exhaustive_Search function, there is no longer any need to call find_permutation(result, input_matrix)
# separately to recover the permutation sequence that was applied to input_matrix to obtain result.
#start_time_find_permutation = time.perf_counter()
#permutation_sequence = find_permutation(result, input_matrix)
#duration_find_permutation = time.perf_counter() - start_time_find_permutation
#print("[accelerated_search_for_good_permutation] Take {:.4f} seconds to finish find_permutation function.".format(duration_find_permutation))
#print("[accelerated_search_for_good_permutation] The permutation sequence is: {:}".format(permutation_sequence))
#print("[accelerated_search_for_good_permutation] The length of permutation sequence is: {:}".format(len(permutation_sequence)))
return permutation_sequence
from .permutation_utilities import *
################################################################################################################
# Exhaustive
# Try them all
# - order of columns within a group doesn't matter
# - order of groups doesn't matter
# - we can eliminate effective duplicates by defining a unique combination to be a sorted list of sorted groups
################################################################################################################
####################################################################
# generate unique permutations
####################################################################
# check if adding a column index to a current permutation would keep it in canonical form
# assumes that perm is in canonical form already!
def is_canonical(perm, col):
# if it's a new group
if len(perm) % 4 == 0:
# every column ID < col needs to be in the permutation already
for val in range(col):
if val not in perm:
return False
# this new group needs to be sorted w.r.t. the previous group
return col > perm[-4]
# not a new group, just check to see if it will still be sorted
return col > perm[-1]
# recursive: build a unique permutation one column index at a time
def generate_unique_combinations(built_permutation, remaining_columns, full_permutation_list, group_width):
# base case: nothing else to add
if len(remaining_columns) == 0:
full_permutation_list.append(np.copy(built_permutation))
if len(full_permutation_list) % 1000000 == 0:
print(f"{len(full_permutation_list)} unique permutations found so far")
# still more choices to make, so add each remaining column in turn if it keeps everything sorted
else:
for c in range(len(remaining_columns)):
# to satisfy our invariants (values within groups are sorted, groups are globally sorted),
# only add this column if either:
# it's starting a new group and is larger than the previous group's first entry
# OR
# it's larger than the last value in the built_permutation
col_to_add = remaining_columns[c]
if is_canonical(built_permutation, col_to_add):
# add the column to the running permutation, remove it from remaining columns
built_permutation.append(col_to_add)
remaining_columns.pop(c)
# recurse
generate_unique_combinations(built_permutation, remaining_columns, full_permutation_list, group_width)
# remove the most recent column and put it back on the remaining column list where we found it (sorted)
remaining_columns.insert(c, built_permutation.pop(-1))
import pickle
import os.path
from os import path
master_unique_permutation_list = {}
def generate_all_unique_combinations(C, M, must_use_all_groups = False):
global master_unique_permutation_list
if len(master_unique_permutation_list) == 0 and path.exists("master_list.pkl"):
with open("master_list.pkl","rb") as cache:
master_unique_permutation_list = pickle.load(cache)
if (C,M) not in master_unique_permutation_list:
full_permutation_list = []
generate_unique_combinations([0], [c for c in range(1,C)], full_permutation_list, M)
master_unique_permutation_list[(C,M)] = full_permutation_list
with open("master_list.pkl", "wb") as cache:
pickle.dump(master_unique_permutation_list, cache)
unique_permutations = master_unique_permutation_list[(C,M)]
return unique_permutations
# analytical solution
import math
def predict_unique_combinations(C, M):
assert(C%M==0)
G = int(C/M)
return int(int(math.factorial(C)) / (int(math.pow(math.factorial(M),G)) * math.factorial(G)))
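# Worked example (added for clarity): with C=8 columns and group width M=4 there are G=2 groups, so
# predict_unique_combinations(8, 4) == 8! / ((4!)**2 * 2!) == 40320 / (576 * 2) == 35 unique permutations.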
#################################################################
# exhaustively try all unique permutations
#################################################################
# exhaustively search the entire matrix
def search_matrix(matrix, group_width):
# give up quickly if we'd go on forever
prediction = predict_unique_combinations(matrix.shape[1], group_width)
best_permutation = [c for c in range(matrix.shape[1])]
if prediction > 1e10:
print(f"There are {prediction} unique combinations with {matrix.shape[1]} columns and a group width of {group_width}, not searching.")
return matrix, prediction, best_permutation, 0.0 # return a zero improvement as well, so callers can always unpack four values
start_time = time.perf_counter()
full_permutation_list = generate_all_unique_combinations(matrix.shape[1], group_width)
# found them, now try them
best_improvement = 0.0
base_sum = sum_after_2_to_4(matrix)
for i in range(1,len(full_permutation_list)):
permutation = full_permutation_list[i]
permuted = matrix[:, permutation]
cur_improvement = sum_after_2_to_4(permuted) - base_sum
if (cur_improvement > best_improvement):
best_improvement = cur_improvement
best_permutation = permutation
seconds = time.perf_counter() - start_time
return matrix[:, best_permutation], seconds, best_permutation, best_improvement
#############
# Stripe group handling
#############
# gather stripes from a larger matrix into a single matrix
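# Illustrative example: with group_width=4, collect_stripes(matrix, [0, 2], 4) returns a matrix whose columns are
# matrix columns 0..3 followed by matrix columns 8..11.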
def collect_stripes(matrix, stripes, group_width):
subset = np.zeros((matrix.shape[0], len(stripes)*group_width))
#print("[Debug][collect_stripes] matrix shape info: {}".format(matrix.shape))
#print("[Debug][collect_stripes] subset info: {}, {}, {}".format(matrix.shape[0], len(stripes), group_width))
for s,stripe in enumerate(stripes):
#print("[Debug][collect_stripes] s: {}, stripe: {}".format(s, stripe))
subset[...,s*group_width:s*group_width+group_width] = matrix[...,stripe*group_width:stripe*group_width+group_width]
return subset
# apply the stripe group permutation to the entire permutation
def apply_stripe_group_permutation(sgp, stripes, group_width, permutation):
new_permutation = permutation.copy()
for subset_idx in range(len(sgp)):
dst_stripe_idx = stripes[int(subset_idx / group_width)]
dst_col_idx = subset_idx % group_width
subset_val = sgp[subset_idx]
src_stripe_idx = stripes[int(subset_val / group_width)]
src_col_idx = subset_val % group_width
new_permutation[dst_stripe_idx*group_width + dst_col_idx] = permutation[src_stripe_idx*group_width + src_col_idx]
return new_permutation
# generate all possible stripe groups
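# Illustrative example: generate_stripe_groups(num_stripes=3, window_size=2) returns {(0, 1), (0, 2), (1, 2)},
# i.e. every sorted combination of window_size stripe indices.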
def generate_stripe_groups(num_stripes, window_size):
stripe_array = [[c] for c in range(num_stripes)]
next_stripe_array = []
for w in range(1, window_size):
for g in range(len(stripe_array)):
start_c = stripe_array[g][w-1]+1
group = stripe_array[g]
for c in range(start_c, num_stripes):
new_group = group.copy()
new_group.append(c)
next_stripe_array.append(new_group)
stripe_array = next_stripe_array
next_stripe_array = []
return set(tuple(stripe_array[g]) for g in range(len(stripe_array)))
# It is not safe to rely on this module-level assignment alone to reset stripe_set to None:
# when Exhaustive_Search is called repeatedly in an end-to-end (E2E) search, stripe_set would not be reset between calls (see the explicit reset in Exhaustive_Search).
stripe_set = None
stripe_set_config = None
# build the stripe map
def build_stripe_map(matrix, group_width, window_size, stripe_map, stripe_ids, perm_map, used_stripes):
global stripe_set, stripe_set_config
#print("[Debug][build_stripe_map] Now the stripe_set value is: {}".format(stripe_set))
window_size = int(window_size / group_width)
if stripe_set is None or stripe_set_config is None or stripe_set_config != (group_width, window_size):
num_stripes = int(matrix.shape[1] / group_width)
assert(group_width * num_stripes == matrix.shape[1])
stripe_set = generate_stripe_groups(num_stripes, window_size)
#print("[Debug][build_stripe_map] Update stripe_set value as: {}".format(stripe_set))
stripe_set_config = (group_width, window_size)
# step through each, update the stripe_map/stripe_ids if necessary
updates = 0
use_cuda = use_gpu()
gpu_list = []
gpu_groups = []
for i,s in enumerate(stripe_set):
sg = [] # build the group of stripes, check if any members changed
need_update = i >= len(stripe_map)
for stripe in s:
sg.append(stripe)
if stripe in used_stripes:
need_update = True
# pre-populate if we're building fresh
if i >= len(stripe_map):
stripe_ids.append(sg)
stripe_map.append(0.)
perm_map.append([c for c in range(group_width * window_size)])
# update entries if needed (only stripe_map and perm_map)
if need_update:
updates += 1
if not use_cuda: # do the work here if using the CPU
subset = collect_stripes(matrix, sg, group_width)
sub_result, sub_duration, permutation, improvement = search_matrix(subset, group_width)
stripe_map[i] = improvement
perm_map[i] = permutation
else: # otherwise, just track the work needed to farm off to the GPU
gpu_groups.append(sg)
gpu_list.append(i)
if use_cuda: # if using the GPU, perform the work
matrix_view = np.copy(matrix).astype(np.float32).flatten()
all_permutations = generate_all_unique_combinations(window_size*group_width, group_width)
num_permutations = len(all_permutations)
permutation_view = np.copy(np.asarray(all_permutations)).astype(np.uint32).flatten()
stripe_groups_view = np.asarray(gpu_groups).astype(np.uint32).flatten()
num_gpu_groups = len(gpu_list)
gpu_improvement = np.zeros((num_gpu_groups), dtype=np.float32).flatten()
gpu_permutation = np.zeros((num_gpu_groups), dtype=np.uint32).flatten()
result = permutation_search_cuda_kernels.build_permute_map(matrix_view,
matrix.shape[0],
matrix.shape[1],
stripe_groups_view,
num_gpu_groups,
window_size,
permutation_view,
window_size * group_width,
gpu_improvement,
gpu_permutation)
# put the data where python expects it
for i in range(len(gpu_list)):
stripe_map[gpu_list[i]] = gpu_improvement[i]
perm_map[gpu_list[i]] = all_permutations[gpu_permutation[i]]
return stripe_map, stripe_ids, perm_map
# start performing stripe checks
sm_perturbations = 0
sm_perturbation_limit = 0
def use_stripe_map(matrix, group_width, stripe_map, stripe_ids, perm_map, permutation):
global sm_perturbations, sm_perturbation_limit
used_stripes = []
stripe_groups_optimized = 0
improvement = 0.0
# set the traversal order
ix = np.flip(np.argsort(stripe_map)) # small to large --> large to small
for i in range(len(ix)):
stripe_group_id = ix[i]
perm = perm_map[stripe_group_id].copy()
if stripe_map[stripe_group_id] <= 0.0001:
# perturbations
if len(used_stripes) == 0 and sm_perturbations < sm_perturbation_limit:
sm_perturbations += 1
# use this permutation, but swap two channels from left/right halves to include two stripes, no matter the group size
stripe_group_id = ix[np.random.randint(len(ix))]
perm = perm_map[stripe_group_id].copy()
# a little easier to escape from
src = np.random.randint(int(len(perm)/2))
dst = int(len(perm)/2) + np.random.randint(int(len(perm)/2))
perm[src],perm[dst] = perm[dst],perm[src]
else:
break
stripe_group = stripe_ids[stripe_group_id]
# don't work on stripes we've already touched
touched_stripe = False
for stripe in stripe_group:
if stripe in used_stripes:
touched_stripe = True
if touched_stripe:
continue
# apply the permutation we've already found to this stripe group
subset = collect_stripes(matrix, stripe_group, group_width)
sub_result = subset[...,perm]
permutation = apply_stripe_group_permutation(perm, stripe_group, group_width, permutation)
# scatter the results, track what changed
for s,stripe in enumerate(stripe_group):
# see if this group is in canonical form (entry 0 is a multiple of 4 and the values are contiguous)
group = perm[s*group_width:s*group_width+group_width] # columns in this group of the used permutation
changed = False
if group[0] % 4 != 0:
changed = True
for c in range(1,group_width):
if group[c] != group[c-1]+1:
changed = True
break
# if it's not, then it changed
if changed:
used_stripes.append(stripe_group[s])
matrix[...,stripe*group_width:stripe*group_width+group_width] = sub_result[...,s*group_width:s*group_width+group_width]
improvement += stripe_map[stripe_group_id]
stripe_groups_optimized += 1
return matrix, stripe_groups_optimized, stripe_map, stripe_ids, used_stripes, improvement, permutation
# entry point for exhaustive searches - both the entire matrix, as well as stripe groups
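# Usage sketch (illustrative, matching the call in call_permutation_search_kernels.py):
#   permuted_matrix, search_seconds, permutation = Exhaustive_Search(matrix, stripe_group_size=8, escape_attempts=100)
# where 'matrix' is a 2-D numpy array whose column count is a multiple of 4 and 'permutation' is the tracked column
# permutation sequence corresponding to 'permuted_matrix'.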
def Exhaustive_Search(matrix, stripe_group_size=-1, escape_attempts=0, permutation=None):
global sm_perturbation_limit, sm_perturbations
sm_perturbations = 0
sm_perturbation_limit = escape_attempts
if permutation is None:
permutation = [c for c in range(matrix.shape[1])]
# It is much safer to reset stripe_set to None here, in the entry point of Exhaustive_Search
global stripe_set, stripe_set_config
stripe_set = None
stripe_set_config = None
# only support N:4 for now
group_width = 4
result = np.copy(matrix)
# if the matrix is too large for a window size of 12, subdivide, then fix up with a global optimization with a window size of 8
if group_width==4 and stripe_group_size==12 and matrix.shape[1] > 512:
stripe_split = int(matrix.shape[1]/2/group_width)
col_split = stripe_split * group_width
result[:,:col_split], durationL, permutation[:col_split] = Exhaustive_Search(result[:,:col_split], stripe_group_size=stripe_group_size, escape_attempts=escape_attempts, permutation=permutation[:col_split])
result[:,col_split:], durationR, permutation[col_split:] = Exhaustive_Search(result[:,col_split:], stripe_group_size=stripe_group_size, escape_attempts=escape_attempts, permutation=permutation[col_split:])
escape_attempts = max(escape_attempts, 100)*10
result,duration,permutation = Exhaustive_Search(result, stripe_group_size=8, escape_attempts=escape_attempts, permutation=permutation)
return result, durationL+durationR+duration, permutation
# if the search window is smaller than the matrix, optimize sliding stripe groups iteratively; otherwise the matrix is small enough to search exhaustively in one pass
if stripe_group_size != -1 and stripe_group_size < matrix.shape[1]:
stripe_map = []
stripe_ids = []
perm_map = []
used_stripes = []
optimized_groups_count = 0
agg_improvement = 0.
cur_total_sum = sum_after_2_to_4(result)
# in practice, this work will be cached ahead of time; doing it now.
# (Reading the cached list from disk can take several seconds, which shouldn't be counted against the search, but amortized over every layer in a network)
generate_all_unique_combinations(stripe_group_size, group_width)
start_time = time.perf_counter()
while True:
#print("[Debug][Exhaustive_Search] Before entering the build_stripe_map function.")
#print("[Debug][Exhaustive_Search] Now the stripe_set value is: {}".format(stripe_set))
stripe_map, stripe_ids, perm_map = build_stripe_map(result, group_width, stripe_group_size, stripe_map, stripe_ids, perm_map, used_stripes)
result, stripe_groups_optimized, stripe_map, stripe_ids, used_stripes, improvement, permutation = use_stripe_map(result, group_width, stripe_map, stripe_ids, perm_map, permutation)
# converged?
if len(used_stripes) == 0:
break
duration = time.perf_counter() - start_time
else: # no sliding window, single iteration
print(f"Matrix has {matrix.shape[1]} columns and the search window is only {stripe_group_size}: searching exhaustively")
result, duration, permutation, improvement = search_matrix(matrix, group_width)
return result, duration, permutation
import numpy as np
import time
import ctypes
import subprocess
import os
import math
gpus_tested = False
gpus_found = 0
kernels_found = True
try:
import permutation_search_cuda as permutation_search_cuda_kernels
print(f"Found permutation search CUDA kernels")
except ImportError:
print(f"Could not find permutation search CUDA kernels, falling back to CPU path")
kernels_found = False
def use_gpu(initial_override = True):
global gpus_tested, gpus_found, kernels_found
if not gpus_tested:
if not initial_override:
gpus_tested = True
return False
try:
gpus_found = str(subprocess.check_output(["nvidia-smi", "-L"])).count('UUID')
print(f"Found {gpus_found} gpus")
except (OSError, subprocess.SubprocessError):
gpus_found = 0
print("Could not run nvidia-smi; please check your CUDA installation")
gpus_tested = True
return gpus_found > 0 and kernels_found
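# Note (added for clarity, not in the upstream source): use_gpu() returns True only
# when nvidia-smi reports at least one device AND the optional permutation_search_cuda
# extension imported above was found. The first call latches the probe result, so
# calling use_gpu(initial_override=False) before anything else pins the module to the
# numpy fallback paths below.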
##############################################################################################
# pruning utilities
##############################################################################################
## apply 2:4 to some matrix
def apply_2_to_4(matrix):
for row in range(matrix.shape[0]):
for col in range(0,matrix.shape[1],4):
ix = np.argsort(np.abs(matrix[row,col:col+4]))
matrix[row,col+ix[0]] = 0.0
matrix[row,col+ix[1]] = 0.0
return matrix
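# Worked example (added for illustration): for the row [0.1, -0.5, 0.2, 0.9],
# the two smallest magnitudes (0.1 and 0.2) are zeroed, giving [0.0, -0.5, 0.0, 0.9].
# Note the pruning happens in place on the argument as well as being returned.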
## find the sum of magnitudes if 2:4 were applied to a matrix
def sum_after_2_to_4(matrix):
#matrix = np.copy(matrix)
cur_sum = 0.0
use_cuda = use_gpu()
if not use_cuda:
start_time = time.perf_counter()
for row in range(matrix.shape[0]):
for col in range(0,matrix.shape[1],4):
ix = np.argsort(np.abs(matrix[row,col:col+4]))
cur_sum += abs(matrix[row,col+ix[2]])
cur_sum += abs(matrix[row,col+ix[3]])
np_elapsed = time.perf_counter() - start_time
else:
matrix = matrix.astype(np.float32)
cuda_sum = np.zeros((1), dtype=np.float32)
start_time = time.perf_counter()
matrix_view = np.copy(matrix).flatten()
sum_view = cuda_sum.flatten()
blocks = max(int(matrix.shape[1]/4/2), 1)
threads = min(max(math.ceil(matrix.shape[0]/4), 1), 1024)
result = permutation_search_cuda_kernels.sum_after_2_to_4(matrix_view,
matrix.shape[0],
matrix.shape[1],
0,
matrix.shape[1],
blocks,
threads,
sum_view)
cuda_elapsed = time.perf_counter() - start_time
#print(cuda_sum, cuda_elapsed, cur_sum, np_elapsed, np_elapsed/cuda_elapsed)
cur_sum = sum_view[0]
return cur_sum
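# Worked example (added for illustration): for the row [0.1, -0.5, 0.2, 0.9],
# the two largest magnitudes per group of four are kept, so the function returns
# 0.5 + 0.9 = 1.4. The input matrix itself is not modified; only the magnitude
# that 2:4 pruning would preserve is reported.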
## try swapping columns and tracking magnitude after pruning
def try_swap(matrix, dst, src):
src_base = sum_after_2_to_4(matrix[...,int(src/4)*4:int(src/4)*4+4])
dst_base = sum_after_2_to_4(matrix[...,int(dst/4)*4:int(dst/4)*4+4])
# swap
matrix[...,[src,dst]] = matrix[...,[dst,src]]
# check the Nx4 slices of the swapped columns
src_sum = sum_after_2_to_4(matrix[...,int(src/4)*4:int(src/4)*4+4])
dst_sum = sum_after_2_to_4(matrix[...,int(dst/4)*4:int(dst/4)*4+4])
# swap back
matrix[...,[src,dst]] = matrix[...,[dst,src]]
return src_sum + dst_sum, (src_sum + dst_sum) - (src_base + dst_base)
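# Note (added for clarity): try_swap(matrix, dst, src) temporarily swaps columns src
# and dst, scores the Nx4 group(s) containing each column with sum_after_2_to_4, then
# restores the original column order. It returns (swapped_score, swapped_score - baseline);
# a positive second element means the swap would increase the magnitude kept by 2:4 pruning.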
##############################################################################################
# permutation utilities
##############################################################################################
## find the permutation needed to make matrix A look like matrix B
def find_permutation(A, B):
permutation = []
for col in range(A.shape[1]):
Avals = A[...,col]
for bcol in range(B.shape[1]):
if np.all(Avals - B[...,bcol] == np.zeros(Avals.shape)):
permutation.append(bcol)
break
return permutation
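# Worked example (added for illustration):
#   A = np.array([[1., 2., 3.]]); B = np.array([[3., 1., 2.]])
#   find_permutation(A, B) returns [1, 2, 0], i.e. column i of A equals
#   column permutation[i] of B.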
...@@ -55,7 +55,7 @@ def main(args):
step = train_loop(args, model, optimizer, step, args.num_dense_steps)
# simulate sparsity by inserting zeros into existing dense weights
-ASP.enable_sparsity()
+ASP.compute_sparse_masks()
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
...
...@@ -50,7 +50,7 @@ def main(step, args, model_state_dict, optimizer_state_dict):
model.load_state_dict(model_state_dict)
optimizer.load_state_dict(optimizer_state_dict)
-print("Model sparsity is %s" % ("enabled" if ASP.sparsity_is_enabled() else "disabled"))
+print("Model sparsity is %s" % ("enabled" if ASP.is_sparsity_enabled() else "disabled"))
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
...
...@@ -59,7 +59,7 @@ def main(args):
step = train_loop(args, model, optimizer, step, args.num_dense_steps)
# simulate sparsity by inserting zeros into existing dense weights
-ASP.enable_sparsity()
+ASP.compute_sparse_masks()
# train for a few steps with sparse weights
print("SPARSE :: ",one_ll)
...
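# Hedged sketch (added for illustration, not part of the diff above): how the renamed
# ASP calls are typically used together. Only ASP.compute_sparse_masks() and
# ASP.is_sparsity_enabled() come from the hunks above; the model, optimizer, and the
# ASP.init_* calls follow the usual ASP recipe and should be treated as assumptions here.
import torch
from apex.contrib.sparsity import ASP

model = torch.nn.Linear(64, 64).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# register masks for prunable layers and let the optimizer re-apply them each step
ASP.init_model_for_pruning(model, mask_calculator="m4n2_1d",
                           whitelist=[torch.nn.Linear], allow_recompute_mask=False)
ASP.init_optimizer_for_pruning(optimizer)

# ... train dense for a while, then switch to 2:4 sparsity ...
ASP.compute_sparse_masks()
print("Model sparsity is %s" % ("enabled" if ASP.is_sparsity_enabled() else "disabled"))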