"docs/source/vscode:/vscode.git/clone" did not exist on "d873acc2545e5b73be75d0e18cedfe7163febf88"
Commit c25a91b6 authored by aiss's avatar aiss
Browse files

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
from deepspeed.nebula.constants import *


class DeepSpeedNebulaConfig(DeepSpeedConfigObject):

    def __init__(self, param_dict):
        super(DeepSpeedNebulaConfig, self).__init__()
@@ -26,29 +26,18 @@ class DeepSpeedNebulaConfig(DeepSpeedConfigObject):
        self._initialize(nebula_dict)

    def _initialize(self, nebula_dict):
        self.enabled = get_scalar_param(nebula_dict, NEBULA_ENABLED, NEBULA_ENABLED_DEFAULT)

        self.load_path = get_scalar_param(nebula_dict, NEBULA_LOAD_PATH, NEBULA_LOAD_PATH_DEFAULT)

        self.enable_nebula_load = get_scalar_param(nebula_dict, NEBULA_ENABLE_NEBULA_LOAD,
                                                   NEBULA_ENABLE_NEBULA_LOAD_DEFAULT)

        self.persistent_storage_path = get_scalar_param(nebula_dict, NEBULA_PERSISTENT_STORAGE_PATH,
                                                        NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT)

        self.persistent_time_interval = get_scalar_param(nebula_dict, NEBULA_PERSISTENT_TIME_INTERVAL,
                                                         NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT)

        self.num_of_version_in_retention = get_scalar_param(nebula_dict, NEBULA_NUM_OF_VERSION_IN_RETENTION,
                                                            NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT)
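For reference, a minimal sketch of the matching user-facing configuration. The key strings below are assumptions inferred from the NEBULA_* constant names (check deepspeed/nebula/constants.py for the authoritative values), and the constructor is assumed to read the "nebula" sub-dict of the DeepSpeed config:

# Hypothetical example; the "nebula" section name and key strings are assumptions.
ds_config = {
    "nebula": {
        "enabled": True,
        "load_path": None,
        "enable_nebula_load": True,
        "persistent_storage_path": "/tmp/nebula_checkpoints",
        "persistent_time_interval": 100,
        "num_of_version_in_retention": 2,
    }
}
nebula_config = DeepSpeedNebulaConfig(ds_config)
print(nebula_config.enabled, nebula_config.persistent_time_interval)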
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

#########################################
# nebula
@@ -63,24 +62,11 @@ NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT = 2

# Nebula envs
NEBULA_EXPORT_ENVS = [
    'DLTS_JOB_ID', 'DLTS_NUM_WORKER', 'NEBULA_PERSISTENT_STORAGE_PATH', 'NEBULA_PERSISTENT_TIME_INTERVAL',
    'AML_RUN_ID', 'AZUREML_RUN_TOKEN', 'AZUREML_WORKSPACE_SCOPE', 'AZUREML_EXPERIMENT_SCOPE',
    'AZUREML_RUN_HISTORY_SERVICE_ENDPOINT', 'AZUREML_RUN_ID', 'NEBULA_MEMORY_BUFFER_SIZE',
    'AZUREML_PARAMETER_ITPJOB_NAME', 'FC_TASKROLE_NAME', 'FC_TASK_INDEX', 'MASTER_HOST', 'LOCAL_HOST',
    'AZUREML_BLOB_ACCOUNT_NAME', 'AZUREML_BLOB_ACCOUNT_KEY'
]

# ITP env files
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from . import adam
from . import adagrad
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .cpu_adagrad import DeepSpeedCPUAdagrad
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.ops.op_builder import CPUAdagradBuilder
@@ -10,13 +11,7 @@ from deepspeed.utils.logging import should_log_le


class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
    optimizer_id = 0

    def __init__(self, model_params, lr=1e-2, eps=1e-10, weight_decay=0, amsgrad=False, fp32_optimizer_states=True):
        default_args = dict(lr=lr, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        super(DeepSpeedCPUAdagrad, self).__init__(model_params, default_args)
@@ -26,11 +21,7 @@ class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
        self.fp32_optimizer_states = fp32_optimizer_states
        self.ds_opt_adagrad = CPUAdagradBuilder().load()

        self.ds_opt_adagrad.create_adagrad(self.opt_id, lr, eps, weight_decay, should_log_le("info"))

    def __del__(self):
        # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize
@@ -90,9 +81,7 @@ class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
                    #memory_format=torch.preserve_format)
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device='cpu')
                    #memory_format=torch.preserve_format)

                state['step'] += 1
@@ -100,39 +89,21 @@ class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
                if p.grad.is_sparse == True:
                    sparse_param = p.sparse_mask(p.grad)
                    sparse_exp_avg_sq = state['exp_avg_sq'].sparse_mask(p.grad)
                    self.ds_opt_adagrad.adagrad_update(self.opt_id, state['step'], group['lr'], group['eps'],
                                                       group['weight_decay'], sparse_param.values(), p.grad.values(),
                                                       sparse_exp_avg_sq.values())
                    p[sparse_param.indices()] = sparse_param.values()
                    state['exp_avg_sq'][sparse_exp_avg_sq.indices()] = sparse_exp_avg_sq.values()
                    if fp16_param_groups is not None:
                        fp16_param_groups[group_id][param_id][sparse_param.indices()] = sparse_param.values()
                else:
                    if fp16_param_groups is not None:
                        self.ds_opt_adagrad.adagrad_update_copy(self.opt_id, state['step'], group['lr'], group['eps'],
                                                                group['weight_decay'], p.data, p.grad.data,
                                                                state['exp_avg_sq'],
                                                                fp16_param_groups[group_id][param_id].data)
                    else:
                        self.ds_opt_adagrad.adagrad_update(self.opt_id, state['step'], group['lr'], group['eps'],
                                                           group['weight_decay'], p.data, p.grad.data,
                                                           state['exp_avg_sq'])
        return loss
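A minimal usage sketch, assuming the CPU Adagrad extension can be JIT-built in the current environment (the model and data below are placeholders):

import torch
from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad

model = torch.nn.Linear(8, 2)                      # placeholder CPU model
optimizer = DeepSpeedCPUAdagrad(model.parameters(), lr=1e-2)

loss = model(torch.randn(4, 8)).sum()              # placeholder forward pass
loss.backward()
optimizer.step()                                   # update driven by the C++ Adagrad kernel
optimizer.zero_grad()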
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .cpu_adam import DeepSpeedCPUAdam
from .fused_adam import FusedAdam
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from cpuinfo import get_cpu_info
@@ -16,8 +17,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
                 model_params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 amsgrad=False,
@@ -76,14 +76,12 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
        super(DeepSpeedCPUAdam, self).__init__(model_params, default_args)

        cpu_info = get_cpu_info()
        self.cpu_vendor = cpu_info["vendor_id_raw"].lower() if "vendor_id_raw" in cpu_info else "unknown"
        if "amd" in self.cpu_vendor:
            for group_id, group in enumerate(self.param_groups):
                for param_id, p in enumerate(group['params']):
                    if p.dtype == torch.half:
                        logger.warning("FP16 params for CPUAdam may not work on AMD CPUs")
                        break
                else:
                    continue
@@ -95,13 +93,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
        self.fp32_optimizer_states = fp32_optimizer_states
        self.ds_opt_adam = CPUAdamBuilder().load()

        self.ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode,
                                     should_log_le("info"))

    def __del__(self):
@@ -168,45 +160,22 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype

                    # gradient momentums
                    state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    #memory_format=torch.preserve_format)
                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
                    #memory_format=torch.preserve_format)

                state['step'] += 1
                beta1, beta2 = group['betas']

                if fp16_param_groups is not None:
                    self.ds_opt_adam.adam_update_copy(self.opt_id, state['step'], group['lr'], beta1, beta2,
                                                      group['eps'], group['weight_decay'], group['bias_correction'],
                                                      p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'],
                                                      fp16_param_groups[group_id][param_id].data)
                else:
                    self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
                                                 group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
                                                 state['exp_avg'], state['exp_avg_sq'])
        return loss
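Usage mirrors the Adagrad variant; a sketch assuming the CPU Adam extension builds on this machine (placeholder model and data; the adamw_mode keyword is assumed from the create_adam call above):

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

model = torch.nn.Linear(8, 2)                      # placeholder CPU model
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), adamw_mode=True)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()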
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
"""

import torch
from .multi_tensor_apply import MultiTensorApply
@@ -47,12 +49,12 @@ class FusedAdam(torch.optim.Optimizer):
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self,
                 params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 adam_w_mode=True,
                 weight_decay=0.,
@@ -61,11 +63,7 @@ class FusedAdam(torch.optim.Optimizer):
        if amsgrad:
            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay)
        super(FusedAdam, self).__init__(params, defaults)
        self.adam_w_mode = 1 if adam_w_mode else 0
        self.set_grad_none = set_grad_none
@@ -83,12 +81,7 @@ class FusedAdam(torch.optim.Optimizer):
        else:
            super(FusedAdam, self).zero_grad()

    def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None):
        """Performs a single optimization step.

        Arguments:
@@ -121,8 +114,7 @@ class FusedAdam(torch.optim.Optimizer):
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError(
                        'FusedAdam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
                # State initialization
@@ -151,35 +143,13 @@ class FusedAdam(torch.optim.Optimizer):
            if (len(g_16) > 0):
                state['step'] += 1
                multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16],
                                     group['lr'], beta1, beta2, group['eps'], state['step'], self.adam_w_mode,
                                     bias_correction, group['weight_decay'])
            if (len(g_32) > 0):
                state['step'] += 1
                multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32],
                                     group['lr'], beta1, beta2, group['eps'], state['step'], self.adam_w_mode,
                                     bias_correction, group['weight_decay'])

        return loss
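A usage sketch for FusedAdam, assuming a GPU (CUDA or ROCm) is present and the fused Adam kernel is available or can be JIT-built; model and data are placeholders:

import torch
from deepspeed.ops.adam import FusedAdam

model = torch.nn.Linear(8, 2).cuda().half()        # fp16 params exercise the g_16 path above
optimizer = FusedAdam(model.parameters(), lr=1e-3, adam_w_mode=True)

loss = model(torch.randn(4, 8, device='cuda', dtype=torch.half)).sum()
loss.backward()
optimizer.step()                                   # one multi_tensor_applier launch per dtype bucket
optimizer.zero_grad()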
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex, commit a109f85
"""


class MultiTensorApply(object):

    def __init__(self, chunk_size):
        self.chunk_size = chunk_size
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from ..op_builder import AsyncIOBuilder
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .fused_lamb import FusedLamb
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex/optimizer/fused_adam and implements the LAMB optimizer
"""

import types
import torch
from deepspeed.ops.op_builder import FusedLambBuilder
@@ -35,12 +37,12 @@ class FusedLamb(torch.optim.Optimizer):
        min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01)
        amsgrad (boolean, optional): NOT SUPPORTED in FusedLamb!
    """

    def __init__(self,
                 params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 eps_inside_sqrt=False,
                 weight_decay=0.,
@@ -64,12 +66,7 @@ class FusedLamb(torch.optim.Optimizer):
        self.eps_mode = 0 if eps_inside_sqrt else 1
        self.lamb_coeffs = []

    def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None):
        """Performs a single optimization step.

        Arguments:
@@ -114,7 +111,8 @@ class FusedLamb(torch.optim.Optimizer):
        #remove the previous coeffs
        del self.lamb_coeffs[:]
        for group, grads_this_group, output_params_this_group, grad_norm_group in zip(
                self.param_groups, grads_group, output_params_group, grad_norms):
            if grads_this_group is None:
                grads_this_group = [None] * len(group['params'])
            if output_params_this_group is None:
@@ -127,7 +125,8 @@ class FusedLamb(torch.optim.Optimizer):
            bias_correction = 1 if group['bias_correction'] else 0

            for p, grad, output_param, grad_norm in zip(group['params'], grads_this_group, output_params_this_group,
                                                        grad_norm_group):

                # compute combined scale factor for this group
                combined_scale = scale
@@ -162,24 +161,10 @@ class FusedLamb(torch.optim.Optimizer):
                state['step'] += 1

                out_p = torch.tensor([], dtype=torch.float) if output_param is None else output_param
                lamb_coeff = self.fused_lamb_cuda.lamb(p.data, out_p, exp_avg, exp_avg_sq, grad, group['lr'], beta1,
                                                       beta2, max_coeff, min_coeff, group['eps'], combined_scale,
                                                       state['step'], self.eps_mode, bias_correction,
                                                       group['weight_decay'])
                self.lamb_coeffs.append(lamb_coeff)
        return loss
...
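A usage sketch for FusedLamb, assuming the LAMB CUDA/ROCm kernel builds; the max_coeff/min_coeff keywords follow the docstring above (verify the full signature), and the model is a placeholder:

import torch
from deepspeed.ops.lamb import FusedLamb

model = torch.nn.Linear(8, 2).cuda()               # placeholder GPU model
optimizer = FusedLamb(model.parameters(), lr=1e-3, max_coeff=10.0, min_coeff=0.01)

loss = model(torch.randn(4, 8, device='cuda')).sum()
loss.backward()
optimizer.step()                                   # per-parameter LAMB coefficients accumulate in optimizer.lamb_coeffs
optimizer.zero_grad()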
import copy
import torch
import deepspeed
from deepspeed.ops import DeepSpeedTransformerConfig
def _copy_child_transformer_state(new_module, orig_child, pre_layer_norm):
# copy relevant state from original child -> new module
qw = orig_child.attention.self.query.weight
qb = orig_child.attention.self.query.bias
kw = orig_child.attention.self.key.weight
kb = orig_child.attention.self.key.bias
vw = orig_child.attention.self.value.weight
vb = orig_child.attention.self.value.bias
qkvw = torch.cat((qw, kw, vw), 0)
qkvb = torch.cat((qb, kb, vb), 0)
#qw.data,kw.data,vw.data = torch.chunk(qkvw, 3, axis=0)
#qb.data,kb.data,vb.data = torch.chunk(qkvb, 3, axis=0)
new_module.attn_qkvw.data = qkvw
new_module.attn_qkvb.data = qkvb
new_module.attn_ow.data = orig_child.attention.output.dense.weight
new_module.attn_ob.data = orig_child.attention.output.dense.bias
if pre_layer_norm:
attention_layernorm = orig_child.PostAttentionLayerNorm
else:
attention_layernorm = orig_child.attention.output.LayerNorm
new_module.attn_nw.data = attention_layernorm.weight
new_module.attn_nb.data = attention_layernorm.bias
if pre_layer_norm:
intermediate_ff = orig_child.intermediate.dense_act
else:
intermediate_ff = orig_child.intermediate.dense
new_module.inter_w.data = intermediate_ff.weight
new_module.inter_b.data = intermediate_ff.bias
new_module.output_w.data = orig_child.output.dense.weight
new_module.output_b.data = orig_child.output.dense.bias
if pre_layer_norm:
transformer_layernorm = orig_child.PreAttentionLayerNorm
else:
transformer_layernorm = orig_child.output.LayerNorm
new_module.norm_w.data = transformer_layernorm.weight
new_module.norm_b.data = transformer_layernorm.bias
def _replace_transformer_layer(orig_layer_impl, model, transformer_config):
""" Replace bert-style transformer layers with DeepSpeed's transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
e.g., transformers.modeling_bert.BertLayer.
model (torch.nn.Module): user's nn.module representing their model
transformer_config (dict): deepspeed transformer layer config containing hidden size, attention heads, etc.
Returns:
Updated nn.module with replaced transformer layers
"""
def replace_fn(child):
new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)
_copy_child_transformer_state(new_module,
child,
transformer_config.pre_layer_norm)
return new_module
return _replace_module(model=model,
orig_class=orig_layer_impl,
replace_fn=replace_fn)
def replace_module(orig_module_impl, model, replacement_module_config):
""" Replace client module
Arguments:
orig_module_impl (torch.nn.Module): original module implementation to replace,
e.g., transformers.modeling_bert.BertLayer.
model (torch.nn.Module): user's nn.module representing their model
replacement_module_config (dict): deepspeed replacement module config (e.g., DeepSpeedTransformerConfig).
Returns:
Updated nn.module with replaced modules
"""
assert isinstance(replacement_module_config, DeepSpeedTransformerConfig), \
'Only DeepSpeedTransformerConfig is currently supported as replacement config'
return _replace_transformer_layer(orig_layer_impl=orig_module_impl,
model=model,
transformer_config=replacement_module_config)
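A hypothetical call to the replace_module helper above; the DeepSpeedTransformerConfig field names are assumptions based on the DeepSpeed transformer-kernel tutorial, and `model` stands in for the user's BERT model:

# Hypothetical sketch; verify the DeepSpeedTransformerConfig fields against your DeepSpeed version.
from transformers.models.bert.modeling_bert import BertLayer
from deepspeed.ops import DeepSpeedTransformerConfig

ds_transformer_config = DeepSpeedTransformerConfig(batch_size=8,
                                                   hidden_size=768,
                                                   intermediate_size=3072,
                                                   heads=12,
                                                   attn_dropout_ratio=0.1,
                                                   hidden_dropout_ratio=0.1,
                                                   num_hidden_layers=12,
                                                   pre_layer_norm=False)
model = replace_module(orig_module_impl=BertLayer,
                       model=model,                # user's BERT model (placeholder)
                       replacement_module_config=ds_transformer_config)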
def _revert_transformer_layer(orig_layer_impl, model, bert_config, transformer_config):
""" Revert DeepSpeed's transformer layer back to original bert-style transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
e.g., transformers.modeling_bert.BertLayer.
model (torch.nn.Module): user's nn.module representing their model
bert_config (dict): model config containing hidden size, attention heads, etc.
transformer_config (dict): deepspeed transformer config used for replacement
Returns:
Updated nn.module with original bert-style transformer layers
"""
def replace_fn(child):
#from turing.nvidia_modelingpreln import BertLayer
orig_module = orig_layer_impl(bert_config)
# copy relevant state from child -> original module
qkvw = child.attn_qkvw.data
qkvb = child.attn_qkvb.data
qw, kw, vw = torch.chunk(qkvw, 3, axis=0)
qb, kb, vb = torch.chunk(qkvb, 3, axis=0)
orig_module.attention.self.query.weight.data = qw
orig_module.attention.self.query.bias.data = qb
orig_module.attention.self.key.weight.data = kw
orig_module.attention.self.key.bias.data = kb
orig_module.attention.self.value.weight.data = vw
orig_module.attention.self.value.bias.data = vb
orig_module.attention.output.dense.weight.data = child.attn_ow.data
orig_module.attention.output.dense.bias.data = child.attn_ob.data
attn_ln_w = child.attn_nw.data
attn_ln_b = child.attn_nb.data
if transformer_config.pre_layer_norm:
orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
else:
orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
orig_module.attention.output.LayerNorm.bias.data = attn_ln_b
inter_ff_w = child.inter_w.data
inter_ff_b = child.inter_b.data
if transformer_config.pre_layer_norm:
orig_module.intermediate.dense_act.weight.data = inter_ff_w
orig_module.intermediate.dense_act.bias.data = inter_ff_b
else:
orig_module.intermediate.dense.weight.data = inter_ff_w
orig_module.intermediate.dense.bias.data = inter_ff_b
orig_module.output.dense.weight.data = child.output_w.data
orig_module.output.dense.bias.data = child.output_b.data
transformer_ln_w = child.norm_w.data
transformer_ln_b = child.norm_b.data
if transformer_config.pre_layer_norm:
orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
else:
orig_module.output.LayerNorm.weight.data = transformer_ln_w
orig_module.output.LayerNorm.bias.data = transformer_ln_b
return orig_module
return _replace_module(model=model,
orig_class=deepspeed.DeepSpeedTransformerLayer,
replace_fn=replace_fn)
def revert_module(orig_module_impl,
model,
orig_module_config,
replacement_module_config):
""" Revert DeepSpeed's module back to original client module
Arguments:
orig_module_impl (torch.nn.Module): the original module that was replaced,
e.g., transformers.modeling_bert.BertLayer.
model (torch.nn.Module): user's nn.module representing their model
orig_module_config (dict): original module configuration
replacement_module_config (dict): replacement deepspeed module configuration
Returns:
Updated nn.module with original bert-style transformer layers
"""
assert isinstance(replacement_module_config, DeepSpeedTransformerConfig), \
'Only DeepSpeedTransformerConfig is currently supported as replacement config'
return _revert_transformer_layer(orig_layer_impl=orig_module_impl,
model=model,
bert_config=orig_module_config,
transformer_config=replacement_module_config)
def _replace_module(model, orig_class, replace_fn):
""" Scan the model for instances of ``orig_clas:`` to replace using ``replace_fn``.
Arguments:
model (torch.nn.Module): the model to augment
orig_class (torch.nn.Module): the module to search for
replace_fn (method): a method to convert instances of ``orig_class`` to the
desired type and return a new instance.
Returns:
A modified ``model``.
"""
policy = {orig_class: replace_fn}
return _replace_module_using_policies(model, policy)
def _replace_module_using_policies(model, policies):
""" Traverse model's children recursively and apply any transformations in ``policies``.
Arguments:
model (torch.nn.Module): model to augment
policies (dict): Mapping of source class to replacement function.
Returns:
Modified ``model``.
"""
for name, child in model.named_children():
if child.__class__ in policies:
orig = repr(child)
setattr(model, name, policies[child.__class__](child))
new = getattr(model, name)
else:
_replace_module_using_policies(child, policies)
return model
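For illustration only, the policy mechanism can be exercised directly with a toy replacement; the class and policy below are hypothetical and not part of this merge request:

import torch

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

# Swap every Linear for an Identity, purely to show how the policy dict drives the traversal.
toy_policies = {torch.nn.Linear: lambda child: torch.nn.Identity()}
patched = _replace_module_using_policies(Tiny(), toy_policies)
print(patched.proj)  # -> Identity()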
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .quantizer import ds_quantizer
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.ops.op_builder import QuantizerBuilder
...
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .dropping_utils import gpt_sample_tokens, bert_sample_tokens, GatherTokens, ScatterTokens
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.ops.op_builder import RandomLTDBuilder
@@ -23,9 +25,7 @@ def gpt_sample_tokens(reserved_length: int,
    prob_dist = torch.ones((layers * batch_size, seq_length), device=device)
    sampled_indices = torch.multinomial(prob_dist, reserved_length)

    sampled_indices = sampled_indices.reshape(layers, batch_size, reserved_length).to(torch.int32)
    global random_ltd_module
    if random_ltd_module is None:
        random_ltd_module = RandomLTDBuilder().load()
@@ -59,9 +59,7 @@ def bert_sample_tokens(reserved_length: int,
    prob_dist = torch.ones((layers * batch_size, seq_length), device=device)
    sampled_indices = torch.multinomial(prob_dist, reserved_length)

    sampled_indices = sampled_indices.reshape(layers, batch_size, reserved_length).to(torch.int32)
    global random_ltd_module
    if random_ltd_module is None:
        random_ltd_module = RandomLTDBuilder().load()
@@ -82,11 +80,9 @@ def bert_sample_tokens(reserved_length: int,


class GatherTokens(torch.autograd.Function):

    @staticmethod
    def forward(ctx, activations: torch.Tensor, sorted_indices: torch.Tensor, batch_first: bool):
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
@@ -104,25 +100,18 @@ class GatherTokens(torch.autograd.Function):
        activations, sorted_indices = ctx.saved_tensors
        batch_first = ctx.batch_first
        return random_ltd_module.token_scatter_(a_gradients, g_gradients, sorted_indices, batch_first), None, None


class ScatterTokens(torch.autograd.Function):

    @staticmethod
    def forward(ctx, all_activations: torch.Tensor, layer_activations: torch.Tensor, sorted_indices: torch.Tensor,
                batch_first: bool):
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        scatter_results = random_ltd_module.token_scatter_(all_activations.clone(), layer_activations, sorted_indices,
                                                           batch_first)

        ctx.save_for_backward(sorted_indices)
@@ -139,7 +128,5 @@ class ScatterTokens(torch.autograd.Function):
        sorted_indices, = ctx.saved_tensors
        batch_first = ctx.batch_first

        ret_val = random_ltd_module.token_gather(out_gradients, sorted_indices, batch_first)
        return out_gradients, ret_val, None, None
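As a conceptual sketch of what the gather/scatter pair does, in pure PyTorch with a batch-first layout; this only illustrates the idea and does not call the RandomLTD CUDA op:

import torch

batch, seq, hidden, reserved = 2, 8, 4, 3
activations = torch.randn(batch, seq, hidden)
# indices of the tokens kept for this layer, one sorted set per batch row
sorted_indices = torch.stack([torch.randperm(seq)[:reserved].sort().values for _ in range(batch)])

# "gather": pull the reserved tokens out of the full sequence -> [batch, reserved, hidden]
gathered = torch.stack([activations[b, sorted_indices[b]] for b in range(batch)])

# "scatter": write the (possibly transformed) reserved tokens back into a full-length copy
scattered = activations.clone()
for b in range(batch):
    scattered[b, sorted_indices[b]] = gathered[b]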
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .sparsity_config import SparsityConfig, DenseSparsityConfig, FixedSparsityConfig, VariableSparsityConfig, BigBirdSparsityConfig, BSLongformerSparsityConfig, LocalSlidingWindowSparsityConfig
from .sparse_self_attention import SparseSelfAttention
...
""" # Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team # SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
from torch import nn from torch import nn
from deepspeed.ops.sparse_attention import SparseSelfAttention, FixedSparsityConfig from deepspeed.ops.sparse_attention import SparseSelfAttention, FixedSparsityConfig
...@@ -13,6 +14,7 @@ class BertSparseSelfAttention(nn.Module): ...@@ -13,6 +14,7 @@ class BertSparseSelfAttention(nn.Module):
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial. For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial.
""" """
def __init__( def __init__(
self, self,
config, config,
...@@ -29,10 +31,8 @@ class BertSparseSelfAttention(nn.Module): ...@@ -29,10 +31,8 @@ class BertSparseSelfAttention(nn.Module):
super(BertSparseSelfAttention, self).__init__() super(BertSparseSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0: if config.hidden_size % config.num_attention_heads != 0:
raise ValueError( raise ValueError("The hidden size (%d) is not a multiple of the number of attention "
"The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads))
"heads (%d)" % (config.hidden_size,
config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size self.all_head_size = self.num_attention_heads * self.attention_head_size
...@@ -44,8 +44,7 @@ class BertSparseSelfAttention(nn.Module): ...@@ -44,8 +44,7 @@ class BertSparseSelfAttention(nn.Module):
self.sparse_self_attention = SparseSelfAttention(sparsity_config) self.sparse_self_attention = SparseSelfAttention(sparsity_config)
def transpose_for_scores(self, x): def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
self.attention_head_size)
x = x.view(*new_x_shape) x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3) return x.permute(0, 2, 1, 3)
......
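The reshape performed by transpose_for_scores can be checked with a quick shape experiment (illustrative sizes only):

import torch

batch, seq, heads, head_size = 2, 16, 12, 64
x = torch.randn(batch, seq, heads * head_size)     # [batch, seq, hidden]
x = x.view(batch, seq, heads, head_size)           # split hidden into heads
x = x.permute(0, 2, 1, 3)                          # [batch, heads, seq, head_size]
assert x.shape == (batch, heads, seq, head_size)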
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
@@ -12,29 +15,8 @@ from deepspeed.accelerator import get_accelerator


@triton.jit
def _kernel(A, B, C, stride_za, stride_ha, stride_ma, stride_ka, stride_zb, stride_hb, stride_kb, stride_nb, stride_zc,
            stride_hc, stride_mc, stride_nc, DS0, DS1, SDD_K, SDD_off_width, lut, locks, nlocks, **meta):
    TM = meta['TM']
    TN = meta['TN']
    TK = meta['TK']
@@ -194,8 +176,7 @@ def _kernel(A,
        tl.store(pc, c, mask=checkc)
    # accumulate partial results using spin-locks
    else:
        plock = locks + tl.program_id(2) * nlocks * tl.num_programs(1) + tl.program_id(1) * nlocks + lockid - 1
        pcount = plock + tl.num_programs(2) * tl.num_programs(1) * nlocks
        while tl.atomic_cas(plock, 0, 1) == 1:
            pass
@@ -292,10 +273,7 @@ class _sparse_matmul(torch.autograd.Function):
        #segmented = _sparse_matmul.sdd_segment(layout.type(torch.int32), start_width)
        start_width = (128 if block > 16 else 32) // block
        layout = layout.type(torch.int32)
        segmented = libtriton.superblock(layout.data_ptr(), layout.shape[0], layout.shape[1], layout.shape[2],
                                         start_width)
        luts, widths, packs = [], [], []
        for size, nnz in segmented:
@@ -317,19 +295,7 @@ class _sparse_matmul(torch.autograd.Function):
        return luts, None, widths, packs

    @staticmethod
    def _sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, luts, num_locks, widths, packs, bench, time):
        if trans_c:
            a, b = b, a
            trans_a, trans_b = not trans_b, not trans_a
@@ -339,9 +305,8 @@ class _sparse_matmul(torch.autograd.Function):
        b_dim = -1 if trans_b else -2
        a_inner, b_inner = a.shape[a_dim], b.shape[b_dim]
        if a_inner != b_inner:
            raise ValueError(f"Size of tensor A along the {a_dim} dim ({a_inner}) must match size "
                             f"of tensor B along the {b_dim} dim ({b_inner})")
        if a_inner % 16 != 0:
            raise ValueError('Reduction size for SDD must be a multiple of 16')
@@ -356,12 +321,7 @@ class _sparse_matmul(torch.autograd.Function):
        device = a.device
        # create kernel
        total_width = sum([width * pack * pack for width, pack in zip(widths, packs)])
        c = torch.empty((batch_size, total_width, block, block), dtype=dtype, device=a.device)
        for lut, width, pack in zip(luts, widths, packs):
            F32TK = [8, 16]
            F16TK = [16]
@@ -387,12 +347,7 @@ class _sparse_matmul(torch.autograd.Function):
            max_width = 49152
            total = 0 if bench else None
            for off_width in range(0, width, max_width):
                grid = lambda meta: [meta['TZ'], min(max_width, width - off_width), batch_size]
                _kernel[grid](a,
                              b,
                              c,
@@ -504,13 +459,7 @@ class _sparse_matmul(torch.autograd.Function):
        # create header
        width = column.size(0)
        offsets += 6 * width
        header = torch.stack((offsets, segments, column, depth, lockid, maxid), dim=1).view(-1).contiguous()
        incs = torch.stack((xincs, wincs), dim=1).view(-1).contiguous()
        incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
        # create lut
@@ -521,19 +470,7 @@ class _sparse_matmul(torch.autograd.Function):
        return lut, num_locks, width, None
    @staticmethod
    def _dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs, bench, time):
        global triton
        if triton is None:
            triton = importlib.import_module('triton')
@@ -548,16 +485,7 @@ class _sparse_matmul(torch.autograd.Function):
        BS2 = block * spdims[1 if trans_b else 2]
        dtype = a.dtype
        # kernel
        meta = {'TN': block, 'TM': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1, 'SDD': False, 'DSD': False, 'DDS': True}
        # output
        CS0 = AS0
        CS1 = AS1
@@ -593,19 +521,7 @@ class _sparse_matmul(torch.autograd.Function):
        return c

    @staticmethod
    def _dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, num_locks, width, packs, bench, time):
        global triton
        if triton is None:
            triton = importlib.import_module('triton')
@@ -621,16 +537,7 @@ class _sparse_matmul(torch.autograd.Function):
        dtype = a.dtype

        # kernel
        meta = {'TM': block, 'TN': 128, 'TK': 16, 'BLOCK': block, 'TZ': 1, 'SDD': False, 'DSD': True, 'DDS': False}
        # output
        CS0 = BS0
        CS1 = BS1
@@ -665,53 +572,14 @@ class _sparse_matmul(torch.autograd.Function):
                      **meta)
        return c
    fn = {'sdd': _sdd_matmul.__get__(object), 'dsd': _dsd_matmul.__get__(object), 'dds': _dds_matmul.__get__(object)}

    @staticmethod
    def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_num_locks, c_width, c_packs,
                c_bench, c_time, da_lut, da_num_locks, da_width, da_packs, da_bench, da_time, db_lut, db_num_locks,
                db_width, db_packs, db_bench, db_time):
        c = _sparse_matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_num_locks, c_width,
                                    c_packs, c_bench, c_time)
        # save for backward
        ctx.save_for_backward(a, b)
        ctx.da_num_locks = da_num_locks
@@ -741,34 +609,14 @@ class _sparse_matmul(torch.autograd.Function):
        # gradients w.r.t. a
        if ctx.needs_input_grad[0]:
            mode_da = mode[1] + mode[0] + mode[2]
            da = _sparse_matmul.fn[mode_da](dc, b, False, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
                                            ctx.da_lut, ctx.da_num_locks, ctx.da_width, ctx.da_packs, ctx.da_bench,
                                            ctx.da_time)
        # gradients w.r.t. b
        if ctx.needs_input_grad[1]:
            mode_db = mode[2] + mode[1] + mode[0]
            db = _sparse_matmul.fn[mode_db](a, dc, not ctx.trans_a, False, ctx.trans_b, ctx.spdims, ctx.block,
                                            ctx.db_lut, ctx.db_num_locks, ctx.db_width, ctx.db_packs, ctx.db_bench,
                                            ctx.db_time)
        return da, db, None, None, None,\
            None, None, None, None,\
@@ -785,6 +633,7 @@ class MatMul:
        For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
    """

    def make_lut(self, dtype, device):
        """Generates the sparsity layout/s used in block-sparse matmul
        """
@@ -797,21 +646,25 @@ class MatMul:
        if self.mode == 'sdd':
            c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
        elif self.mode == 'dsd':
            c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_dxx_lut(layout, block, step, not self.trans_a,
                                                                               device)
        elif self.mode == 'dds':
            c_lut, c_num_locks, c_width, c_packs = _sparse_matmul.make_dxx_lut(layout, block, step, self.trans_b,
                                                                               device)
        # DA look-up table
        if self.mode == 'sdd':
            da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_dxx_lut(layout, block, step, True, device)
        elif self.mode == 'dsd':
            da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
        elif self.mode == 'dds':
            da_lut, da_num_locks, da_width, da_packs = _sparse_matmul.make_dxx_lut(layout, block, step,
                                                                                   not self.trans_b, device)
        # DB look-up table
        if self.mode == 'sdd':
            db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_dxx_lut(layout, block, step, False, device)
        elif self.mode == 'dsd':
            db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_dxx_lut(layout, block, step, self.trans_a,
                                                                                   device)
        elif self.mode == 'dds':
            db_lut, db_num_locks, db_width, db_packs = _sparse_matmul.make_sdd_lut(layout, block, dtype, device)
        self.lut_cache[key] = (c_lut, c_num_locks, c_width, c_packs,\
@@ -845,11 +698,10 @@ class MatMul:
        assert layout_dim in (2, 3), "Layout should be a 2 or 3 dimensional tensor of 0s and 1s"

        if not mode == 'sdd':
            # Dims to be reduced on the 'inside' of the matmul, either -1 or -2
            trans_dense, trans_sparse, sparse_inner = (trans_b, trans_a, -1) if mode == 'dsd' else (trans_a, trans_b,
                                                                                                    -2)
            self.dense_inner_dim = -((sparse_inner % 2) + 1) if not trans_dense else sparse_inner
            sparse_inner = sparse_inner if not trans_sparse else -((sparse_inner % 2) + 1)

            # Inner dim of the dense input should be equal to the inner dim of the sparse input
            self.dense_inner_size = layout.shape[sparse_inner] * block
@@ -860,8 +712,7 @@ class MatMul:
        if layout_dim == 2:
            layout = layout.unsqueeze(0)

        layout = layout.long()  # Above code assumes the layout tensor is an integral type
        self.spdims = layout.shape
        # timings
@@ -909,31 +760,9 @@ class MatMul:
        b = MatMul._pad_shape(b, self.mode == 'dds')

        # execute
        c = _sparse_matmul.apply(a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut,
                                 c_num_locks, c_width, c_packs, self.bench, time_c, da_lut, da_num_locks, da_width,
                                 da_packs, self.bench, time_da, db_lut, db_num_locks, db_width, db_packs, self.bench,
                                 time_db)

        # This removes any leading singleton dimensions we may have added to the tensor that weren't in the input
@@ -948,9 +777,8 @@ class MatMul:
    def _validate_inputs(self, a, b):
        if a.device != b.device:
            raise ValueError(f"Inputs must be on the same device; got {a.device} for tensor A "
                             f"and {b.device} for tensor B")
        if not get_accelerator().on_accelerator(a):
            raise ValueError("Only GPU devices are supported for now")
@@ -958,9 +786,7 @@ class MatMul:
        if torch.is_autocast_enabled():
            a, b = a.half(), b.half()
        elif a.dtype != b.dtype:
            raise ValueError(f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B")

        mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b
        if mode != 'sdd':
@@ -968,14 +794,12 @@ class MatMul:
            dense, dense_name, sparse, sparse_name = (a, 'A', b, 'B') if mode == 'dds' else (b, 'B', a, 'A')
            dense_inner = dense.shape[self.dense_inner_dim]
            if dense_inner != self.dense_inner_size:
                raise ValueError(f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim "
                                 f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.")

            if sparse.shape[-len(self.sparse_shape):] != self.sparse_shape:
                raise ValueError(f"Expected tensor with trailing dimensions of shape {self.sparse_shape} for argument "
                                 f"{sparse_name}, got {sparse.shape}")

        def add_extra_dims(x):
            # Add extra leading singleton dimensions if needed
@@ -984,8 +808,7 @@ class MatMul:
                singletons = [1] * dims_needed
                x = x.view(*singletons, *x.shape)
            elif dims_needed < 0:
                raise ValueError("Tensors with more than 4 dimensions are not currently supported")
            return x
...
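To make the layout convention concrete, here is a small dense reference for the 'sdd' mode (sampled dense-dense: both inputs dense, output sparse). It only shows which output blocks a layout keeps, does not call the Triton kernel above, and uses a simplified 2-D layout (one head):

import torch

block = 2
layout = torch.tensor([[1, 0],
                       [0, 1]])                    # 2x2 grid of blocks, keep the diagonal
a = torch.randn(4, 4)                              # dense A (2 blocks per side)
b = torch.randn(4, 4)                              # dense B
full = a @ b

# 'sdd' materializes only the output blocks where layout == 1
for i in range(layout.shape[0]):
    for j in range(layout.shape[1]):
        if layout[i, j]:
            print((i, j), full[i * block:(i + 1) * block, j * block:(j + 1) * block])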