Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
deepspeed
Commits
5bcc463d
Commit
5bcc463d
authored
May 29, 2023
by
aiss
Browse files
update v0.9.2
parent
ac5fbab4
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
207 additions
and
718 deletions
+207
-718
deepspeed/nebula/config.py
deepspeed/nebula/config.py
+14
-25
deepspeed/nebula/constants.py
deepspeed/nebula/constants.py
+9
-23
deepspeed/ops/__init__.py
deepspeed/ops/__init__.py
+4
-1
deepspeed/ops/adagrad/__init__.py
deepspeed/ops/adagrad/__init__.py
+4
-1
deepspeed/ops/adagrad/cpu_adagrad.py
deepspeed/ops/adagrad/cpu_adagrad.py
+17
-46
deepspeed/ops/adam/__init__.py
deepspeed/ops/adam/__init__.py
+4
-1
deepspeed/ops/adam/cpu_adam.py
deepspeed/ops/adam/cpu_adam.py
+17
-48
deepspeed/ops/adam/fused_adam.py
deepspeed/ops/adam/fused_adam.py
+16
-46
deepspeed/ops/adam/multi_tensor_apply.py
deepspeed/ops/adam/multi_tensor_apply.py
+6
-3
deepspeed/ops/aio/__init__.py
deepspeed/ops/aio/__init__.py
+4
-4
deepspeed/ops/lamb/__init__.py
deepspeed/ops/lamb/__init__.py
+4
-1
deepspeed/ops/lamb/fused_lamb.py
deepspeed/ops/lamb/fused_lamb.py
+16
-31
deepspeed/ops/module_inject.py
deepspeed/ops/module_inject.py
+0
-216
deepspeed/ops/quantizer/__init__.py
deepspeed/ops/quantizer/__init__.py
+4
-1
deepspeed/ops/quantizer/quantizer.py
deepspeed/ops/quantizer/quantizer.py
+5
-3
deepspeed/ops/random_ltd/__init__.py
deepspeed/ops/random_ltd/__init__.py
+4
-1
deepspeed/ops/random_ltd/dropping_utils.py
deepspeed/ops/random_ltd/dropping_utils.py
+14
-27
deepspeed/ops/sparse_attention/__init__.py
deepspeed/ops/sparse_attention/__init__.py
+4
-1
deepspeed/ops/sparse_attention/bert_sparse_self_attention.py
deepspeed/ops/sparse_attention/bert_sparse_self_attention.py
+8
-9
deepspeed/ops/sparse_attention/matmul.py
deepspeed/ops/sparse_attention/matmul.py
+53
-230
No files found.
Too many changes to show.
To preserve performance only
554 of 554+
files are displayed.
Plain diff
Email patch
deepspeed/nebula/config.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
"""
# SPDX-License-Identifier: Apache-2.0
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
# DeepSpeed Team
"""
from
deepspeed.runtime.config_utils
import
get_scalar_param
,
DeepSpeedConfigObject
from
deepspeed.runtime.config_utils
import
get_scalar_param
,
DeepSpeedConfigObject
from
deepspeed.nebula.constants
import
*
from
deepspeed.nebula.constants
import
*
class
DeepSpeedNebulaConfig
(
DeepSpeedConfigObject
):
class
DeepSpeedNebulaConfig
(
DeepSpeedConfigObject
):
def
__init__
(
self
,
param_dict
):
def
__init__
(
self
,
param_dict
):
super
(
DeepSpeedNebulaConfig
,
self
).
__init__
()
super
(
DeepSpeedNebulaConfig
,
self
).
__init__
()
...
@@ -26,29 +26,18 @@ class DeepSpeedNebulaConfig(DeepSpeedConfigObject):
...
@@ -26,29 +26,18 @@ class DeepSpeedNebulaConfig(DeepSpeedConfigObject):
self
.
_initialize
(
nebula_dict
)
self
.
_initialize
(
nebula_dict
)
def
_initialize
(
self
,
nebula_dict
):
def
_initialize
(
self
,
nebula_dict
):
self
.
enabled
=
get_scalar_param
(
nebula_dict
,
self
.
enabled
=
get_scalar_param
(
nebula_dict
,
NEBULA_ENABLED
,
NEBULA_ENABLED_DEFAULT
)
NEBULA_ENABLED
,
NEBULA_ENABLED_DEFAULT
)
self
.
load_path
=
get_scalar_param
(
nebula_dict
,
self
.
load_path
=
get_scalar_param
(
nebula_dict
,
NEBULA_LOAD_PATH
,
NEBULA_LOAD_PATH_DEFAULT
)
NEBULA_LOAD_PATH
,
NEBULA_LOAD_PATH_DEFAULT
)
self
.
enable_nebula_load
=
get_scalar_param
(
nebula_dict
,
self
.
enable_nebula_load
=
get_scalar_param
(
nebula_dict
,
NEBULA_ENABLE_NEBULA_LOAD
,
NEBULA_ENABLE_NEBULA_LOAD
,
NEBULA_ENABLE_NEBULA_LOAD_DEFAULT
)
NEBULA_ENABLE_NEBULA_LOAD_DEFAULT
)
self
.
persistent_storage_path
=
get_scalar_param
(
self
.
persistent_storage_path
=
get_scalar_param
(
nebula_dict
,
NEBULA_PERSISTENT_STORAGE_PATH
,
nebula_dict
,
NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT
)
NEBULA_PERSISTENT_STORAGE_PATH
,
NEBULA_PERSISTENT_STORAGE_PATH_DEFAULT
)
self
.
persistent_time_interval
=
get_scalar_param
(
self
.
persistent_time_interval
=
get_scalar_param
(
nebula_dict
,
NEBULA_PERSISTENT_TIME_INTERVAL
,
nebula_dict
,
NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT
)
NEBULA_PERSISTENT_TIME_INTERVAL
,
NEBULA_PERSISTENT_TIME_INTERVAL_DEFAULT
)
self
.
num_of_version_in_retention
=
get_scalar_param
(
self
.
num_of_version_in_retention
=
get_scalar_param
(
nebula_dict
,
NEBULA_NUM_OF_VERSION_IN_RETENTION
,
nebula_dict
,
NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT
)
NEBULA_NUM_OF_VERSION_IN_RETENTION
,
NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT
)
deepspeed/nebula/constants.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
"""
# SPDX-License-Identifier: Apache-2.0
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
# DeepSpeed Team
"""
#########################################
#########################################
# nebula
# nebula
...
@@ -63,24 +62,11 @@ NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT = 2
...
@@ -63,24 +62,11 @@ NEBULA_NUM_OF_VERSION_IN_RETENTION_DEFAULT = 2
# Neubla envs
# Neubla envs
NEBULA_EXPORT_ENVS
=
[
NEBULA_EXPORT_ENVS
=
[
'DLTS_JOB_ID'
,
'DLTS_JOB_ID'
,
'DLTS_NUM_WORKER'
,
'NEBULA_PERSISTENT_STORAGE_PATH'
,
'NEBULA_PERSISTENT_TIME_INTERVAL'
,
'DLTS_NUM_WORKER'
,
'AML_RUN_ID'
,
'AZUREML_RUN_TOKEN'
,
'AZUREML_WORKSPACE_SCOPE'
,
'AZUREML_EXPERIMENT_SCOPE'
,
'NEBULA_PERSISTENT_STORAGE_PATH'
,
'AZUREML_RUN_HISTORY_SERVICE_ENDPOINT'
,
'AZUREML_RUN_ID'
,
'NEBULA_MEMORY_BUFFER_SIZE'
,
'NEBULA_PERSISTENT_TIME_INTERVAL'
,
'AZUREML_PARAMETER_ITPJOB_NAME'
,
'FC_TASKROLE_NAME'
,
'FC_TASK_INDEX'
,
'MASTER_HOST'
,
'LOCAL_HOST'
,
'AML_RUN_ID'
,
'AZUREML_BLOB_ACCOUNT_NAME'
,
'AZUREML_BLOB_ACCOUNT_KEY'
'AZUREML_RUN_TOKEN'
,
'AZUREML_WORKSPACE_SCOPE'
,
'AZUREML_EXPERIMENT_SCOPE'
,
'AZUREML_RUN_HISTORY_SERVICE_ENDPOINT'
,
'AZUREML_RUN_ID'
,
'NEBULA_MEMORY_BUFFER_SIZE'
,
'AZUREML_PARAMETER_ITPJOB_NAME'
,
'FC_TASKROLE_NAME'
,
'FC_TASK_INDEX'
,
'MASTER_HOST'
,
'LOCAL_HOST'
,
'AZUREML_BLOB_ACCOUNT_NAME'
,
'AZUREML_BLOB_ACCOUNT_KEY'
]
]
# ITP env files
# ITP env files
...
...
deepspeed/ops/__init__.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from
.
import
adam
from
.
import
adam
from
.
import
adagrad
from
.
import
adagrad
...
...
deepspeed/ops/adagrad/__init__.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from
.cpu_adagrad
import
DeepSpeedCPUAdagrad
from
.cpu_adagrad
import
DeepSpeedCPUAdagrad
deepspeed/ops/adagrad/cpu_adagrad.py
View file @
5bcc463d
'''
# Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team
# SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import
torch
import
torch
from
deepspeed.ops.op_builder
import
CPUAdagradBuilder
from
deepspeed.ops.op_builder
import
CPUAdagradBuilder
...
@@ -10,13 +11,7 @@ from deepspeed.utils.logging import should_log_le
...
@@ -10,13 +11,7 @@ from deepspeed.utils.logging import should_log_le
class
DeepSpeedCPUAdagrad
(
torch
.
optim
.
Optimizer
):
class
DeepSpeedCPUAdagrad
(
torch
.
optim
.
Optimizer
):
optimizer_id
=
0
optimizer_id
=
0
def
__init__
(
self
,
def
__init__
(
self
,
model_params
,
lr
=
1e-2
,
eps
=
1e-10
,
weight_decay
=
0
,
amsgrad
=
False
,
fp32_optimizer_states
=
True
):
model_params
,
lr
=
1e-2
,
eps
=
1e-10
,
weight_decay
=
0
,
amsgrad
=
False
,
fp32_optimizer_states
=
True
):
default_args
=
dict
(
lr
=
lr
,
eps
=
eps
,
weight_decay
=
weight_decay
,
amsgrad
=
amsgrad
)
default_args
=
dict
(
lr
=
lr
,
eps
=
eps
,
weight_decay
=
weight_decay
,
amsgrad
=
amsgrad
)
super
(
DeepSpeedCPUAdagrad
,
self
).
__init__
(
model_params
,
default_args
)
super
(
DeepSpeedCPUAdagrad
,
self
).
__init__
(
model_params
,
default_args
)
...
@@ -26,11 +21,7 @@ class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
...
@@ -26,11 +21,7 @@ class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
self
.
fp32_optimizer_states
=
fp32_optimizer_states
self
.
fp32_optimizer_states
=
fp32_optimizer_states
self
.
ds_opt_adagrad
=
CPUAdagradBuilder
().
load
()
self
.
ds_opt_adagrad
=
CPUAdagradBuilder
().
load
()
self
.
ds_opt_adagrad
.
create_adagrad
(
self
.
opt_id
,
self
.
ds_opt_adagrad
.
create_adagrad
(
self
.
opt_id
,
lr
,
eps
,
weight_decay
,
should_log_le
(
"info"
))
lr
,
eps
,
weight_decay
,
should_log_le
(
"info"
))
def
__del__
(
self
):
def
__del__
(
self
):
# need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize
# need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize
...
@@ -90,9 +81,7 @@ class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
...
@@ -90,9 +81,7 @@ class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
#memory_format=torch.preserve_format)
#memory_format=torch.preserve_format)
# gradient variances
# gradient variances
state
[
'exp_avg_sq'
]
=
torch
.
zeros_like
(
p
.
data
,
state
[
'exp_avg_sq'
]
=
torch
.
zeros_like
(
p
.
data
,
dtype
=
state_dtype
,
device
=
'cpu'
)
dtype
=
state_dtype
,
device
=
'cpu'
)
#memory_format=torch.preserve_format)
#memory_format=torch.preserve_format)
state
[
'step'
]
+=
1
state
[
'step'
]
+=
1
...
@@ -100,39 +89,21 @@ class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
...
@@ -100,39 +89,21 @@ class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
if
p
.
grad
.
is_sparse
==
True
:
if
p
.
grad
.
is_sparse
==
True
:
sparse_param
=
p
.
sparse_mask
(
p
.
grad
)
sparse_param
=
p
.
sparse_mask
(
p
.
grad
)
sparse_exp_avg_sq
=
state
[
'exp_avg_sq'
].
sparse_mask
(
p
.
grad
)
sparse_exp_avg_sq
=
state
[
'exp_avg_sq'
].
sparse_mask
(
p
.
grad
)
self
.
ds_opt_adagrad
.
adagrad_update
(
self
.
opt_id
,
self
.
ds_opt_adagrad
.
adagrad_update
(
self
.
opt_id
,
state
[
'step'
],
group
[
'lr'
],
group
[
'eps'
],
state
[
'step'
],
group
[
'weight_decay'
],
sparse_param
.
values
(),
p
.
grad
.
values
(),
group
[
'lr'
],
group
[
'eps'
],
group
[
'weight_decay'
],
sparse_param
.
values
(),
p
.
grad
.
values
(),
sparse_exp_avg_sq
.
values
())
sparse_exp_avg_sq
.
values
())
p
[
sparse_param
.
indices
()]
=
sparse_param
.
values
()
p
[
sparse_param
.
indices
()]
=
sparse_param
.
values
()
state
[
'exp_avg_sq'
][
state
[
'exp_avg_sq'
][
sparse_exp_avg_sq
.
indices
()]
=
sparse_exp_avg_sq
.
values
()
sparse_exp_avg_sq
.
indices
()]
=
sparse_exp_avg_sq
.
values
()
if
fp16_param_groups
is
not
None
:
if
fp16_param_groups
is
not
None
:
fp16_param_groups
[
group_id
][
param_id
][
fp16_param_groups
[
group_id
][
param_id
][
sparse_param
.
indices
()]
=
sparse_param
.
values
()
sparse_param
.
indices
()]
=
sparse_param
.
values
()
else
:
else
:
if
fp16_param_groups
is
not
None
:
if
fp16_param_groups
is
not
None
:
self
.
ds_opt_adagrad
.
adagrad_update_copy
(
self
.
ds_opt_adagrad
.
adagrad_update_copy
(
self
.
opt_id
,
state
[
'step'
],
group
[
'lr'
],
group
[
'eps'
],
self
.
opt_id
,
group
[
'weight_decay'
],
p
.
data
,
p
.
grad
.
data
,
state
[
'step'
],
state
[
'exp_avg_sq'
],
group
[
'lr'
],
fp16_param_groups
[
group_id
][
param_id
].
data
)
group
[
'eps'
],
group
[
'weight_decay'
],
p
.
data
,
p
.
grad
.
data
,
state
[
'exp_avg_sq'
],
fp16_param_groups
[
group_id
][
param_id
].
data
)
else
:
else
:
self
.
ds_opt_adagrad
.
adagrad_update
(
self
.
opt_id
,
self
.
ds_opt_adagrad
.
adagrad_update
(
self
.
opt_id
,
state
[
'step'
],
group
[
'lr'
],
group
[
'eps'
],
state
[
'step'
],
group
[
'weight_decay'
],
p
.
data
,
p
.
grad
.
data
,
group
[
'lr'
],
group
[
'eps'
],
group
[
'weight_decay'
],
p
.
data
,
p
.
grad
.
data
,
state
[
'exp_avg_sq'
])
state
[
'exp_avg_sq'
])
return
loss
return
loss
deepspeed/ops/adam/__init__.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from
.cpu_adam
import
DeepSpeedCPUAdam
from
.cpu_adam
import
DeepSpeedCPUAdam
from
.fused_adam
import
FusedAdam
from
.fused_adam
import
FusedAdam
deepspeed/ops/adam/cpu_adam.py
View file @
5bcc463d
'''
# Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team
# SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import
torch
import
torch
from
cpuinfo
import
get_cpu_info
from
cpuinfo
import
get_cpu_info
...
@@ -16,8 +17,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
...
@@ -16,8 +17,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
model_params
,
model_params
,
lr
=
1e-3
,
lr
=
1e-3
,
bias_correction
=
True
,
bias_correction
=
True
,
betas
=
(
0.9
,
betas
=
(
0.9
,
0.999
),
0.999
),
eps
=
1e-8
,
eps
=
1e-8
,
weight_decay
=
0
,
weight_decay
=
0
,
amsgrad
=
False
,
amsgrad
=
False
,
...
@@ -76,14 +76,12 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
...
@@ -76,14 +76,12 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
super
(
DeepSpeedCPUAdam
,
self
).
__init__
(
model_params
,
default_args
)
super
(
DeepSpeedCPUAdam
,
self
).
__init__
(
model_params
,
default_args
)
cpu_info
=
get_cpu_info
()
cpu_info
=
get_cpu_info
()
self
.
cpu_vendor
=
cpu_info
[
"vendor_id_raw"
].
lower
(
self
.
cpu_vendor
=
cpu_info
[
"vendor_id_raw"
].
lower
()
if
"vendor_id_raw"
in
cpu_info
else
"unknown"
)
if
"vendor_id_raw"
in
cpu_info
else
"unknown"
if
"amd"
in
self
.
cpu_vendor
:
if
"amd"
in
self
.
cpu_vendor
:
for
group_id
,
group
in
enumerate
(
self
.
param_groups
):
for
group_id
,
group
in
enumerate
(
self
.
param_groups
):
for
param_id
,
p
in
enumerate
(
group
[
'params'
]):
for
param_id
,
p
in
enumerate
(
group
[
'params'
]):
if
p
.
dtype
==
torch
.
half
:
if
p
.
dtype
==
torch
.
half
:
logger
.
warning
(
logger
.
warning
(
"FP16 params for CPUAdam may not work on AMD CPUs"
)
"FP16 params for CPUAdam may not work on AMD CPUs"
)
break
break
else
:
else
:
continue
continue
...
@@ -95,13 +93,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
...
@@ -95,13 +93,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
self
.
fp32_optimizer_states
=
fp32_optimizer_states
self
.
fp32_optimizer_states
=
fp32_optimizer_states
self
.
ds_opt_adam
=
CPUAdamBuilder
().
load
()
self
.
ds_opt_adam
=
CPUAdamBuilder
().
load
()
self
.
ds_opt_adam
.
create_adam
(
self
.
opt_id
,
self
.
ds_opt_adam
.
create_adam
(
self
.
opt_id
,
lr
,
betas
[
0
],
betas
[
1
],
eps
,
weight_decay
,
adamw_mode
,
lr
,
betas
[
0
],
betas
[
1
],
eps
,
weight_decay
,
adamw_mode
,
should_log_le
(
"info"
))
should_log_le
(
"info"
))
def
__del__
(
self
):
def
__del__
(
self
):
...
@@ -168,45 +160,22 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
...
@@ -168,45 +160,22 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
state_dtype
=
torch
.
float
if
self
.
fp32_optimizer_states
else
p
.
dtype
state_dtype
=
torch
.
float
if
self
.
fp32_optimizer_states
else
p
.
dtype
# gradient momentums
# gradient momentums
state
[
'exp_avg'
]
=
torch
.
zeros_like
(
p
.
data
,
state
[
'exp_avg'
]
=
torch
.
zeros_like
(
p
.
data
,
dtype
=
state_dtype
,
device
=
device
)
dtype
=
state_dtype
,
device
=
device
)
#memory_format=torch.preserve_format)
#memory_format=torch.preserve_format)
# gradient variances
# gradient variances
state
[
'exp_avg_sq'
]
=
torch
.
zeros_like
(
p
.
data
,
state
[
'exp_avg_sq'
]
=
torch
.
zeros_like
(
p
.
data
,
dtype
=
state_dtype
,
device
=
device
)
dtype
=
state_dtype
,
device
=
device
)
#memory_format=torch.preserve_format)
#memory_format=torch.preserve_format)
state
[
'step'
]
+=
1
state
[
'step'
]
+=
1
beta1
,
beta2
=
group
[
'betas'
]
beta1
,
beta2
=
group
[
'betas'
]
if
fp16_param_groups
is
not
None
:
if
fp16_param_groups
is
not
None
:
self
.
ds_opt_adam
.
adam_update_copy
(
self
.
ds_opt_adam
.
adam_update_copy
(
self
.
opt_id
,
state
[
'step'
],
group
[
'lr'
],
beta1
,
beta2
,
self
.
opt_id
,
group
[
'eps'
],
group
[
'weight_decay'
],
group
[
'bias_correction'
],
state
[
'step'
],
p
.
data
,
p
.
grad
.
data
,
state
[
'exp_avg'
],
state
[
'exp_avg_sq'
],
group
[
'lr'
],
fp16_param_groups
[
group_id
][
param_id
].
data
)
beta1
,
beta2
,
group
[
'eps'
],
group
[
'weight_decay'
],
group
[
'bias_correction'
],
p
.
data
,
p
.
grad
.
data
,
state
[
'exp_avg'
],
state
[
'exp_avg_sq'
],
fp16_param_groups
[
group_id
][
param_id
].
data
)
else
:
else
:
self
.
ds_opt_adam
.
adam_update
(
self
.
opt_id
,
self
.
ds_opt_adam
.
adam_update
(
self
.
opt_id
,
state
[
'step'
],
group
[
'lr'
],
beta1
,
beta2
,
group
[
'eps'
],
state
[
'step'
],
group
[
'weight_decay'
],
group
[
'bias_correction'
],
p
.
data
,
p
.
grad
.
data
,
group
[
'lr'
],
state
[
'exp_avg'
],
state
[
'exp_avg_sq'
])
beta1
,
beta2
,
group
[
'eps'
],
group
[
'weight_decay'
],
group
[
'bias_correction'
],
p
.
data
,
p
.
grad
.
data
,
state
[
'exp_avg'
],
state
[
'exp_avg_sq'
])
return
loss
return
loss
deepspeed/ops/adam/fused_adam.py
View file @
5bcc463d
'''
# Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Copyright NVIDIA/apex
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
'''
"""
import
torch
import
torch
from
.multi_tensor_apply
import
MultiTensorApply
from
.multi_tensor_apply
import
MultiTensorApply
...
@@ -47,12 +49,12 @@ class FusedAdam(torch.optim.Optimizer):
...
@@ -47,12 +49,12 @@ class FusedAdam(torch.optim.Optimizer):
.. _On the Convergence of Adam and Beyond:
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
https://openreview.net/forum?id=ryQu7f-RZ
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
params
,
params
,
lr
=
1e-3
,
lr
=
1e-3
,
bias_correction
=
True
,
bias_correction
=
True
,
betas
=
(
0.9
,
betas
=
(
0.9
,
0.999
),
0.999
),
eps
=
1e-8
,
eps
=
1e-8
,
adam_w_mode
=
True
,
adam_w_mode
=
True
,
weight_decay
=
0.
,
weight_decay
=
0.
,
...
@@ -61,11 +63,7 @@ class FusedAdam(torch.optim.Optimizer):
...
@@ -61,11 +63,7 @@ class FusedAdam(torch.optim.Optimizer):
if
amsgrad
:
if
amsgrad
:
raise
RuntimeError
(
'FusedAdam does not support the AMSGrad variant.'
)
raise
RuntimeError
(
'FusedAdam does not support the AMSGrad variant.'
)
defaults
=
dict
(
lr
=
lr
,
defaults
=
dict
(
lr
=
lr
,
bias_correction
=
bias_correction
,
betas
=
betas
,
eps
=
eps
,
weight_decay
=
weight_decay
)
bias_correction
=
bias_correction
,
betas
=
betas
,
eps
=
eps
,
weight_decay
=
weight_decay
)
super
(
FusedAdam
,
self
).
__init__
(
params
,
defaults
)
super
(
FusedAdam
,
self
).
__init__
(
params
,
defaults
)
self
.
adam_w_mode
=
1
if
adam_w_mode
else
0
self
.
adam_w_mode
=
1
if
adam_w_mode
else
0
self
.
set_grad_none
=
set_grad_none
self
.
set_grad_none
=
set_grad_none
...
@@ -83,12 +81,7 @@ class FusedAdam(torch.optim.Optimizer):
...
@@ -83,12 +81,7 @@ class FusedAdam(torch.optim.Optimizer):
else
:
else
:
super
(
FusedAdam
,
self
).
zero_grad
()
super
(
FusedAdam
,
self
).
zero_grad
()
def
step
(
self
,
def
step
(
self
,
closure
=
None
,
grads
=
None
,
output_params
=
None
,
scale
=
None
,
grad_norms
=
None
):
closure
=
None
,
grads
=
None
,
output_params
=
None
,
scale
=
None
,
grad_norms
=
None
):
"""Performs a single optimization step.
"""Performs a single optimization step.
Arguments:
Arguments:
...
@@ -121,8 +114,7 @@ class FusedAdam(torch.optim.Optimizer):
...
@@ -121,8 +114,7 @@ class FusedAdam(torch.optim.Optimizer):
continue
continue
if
p
.
grad
.
data
.
is_sparse
:
if
p
.
grad
.
data
.
is_sparse
:
raise
RuntimeError
(
raise
RuntimeError
(
'FusedAdam does not support sparse gradients, please consider SparseAdam instead'
'FusedAdam does not support sparse gradients, please consider SparseAdam instead'
)
)
state
=
self
.
state
[
p
]
state
=
self
.
state
[
p
]
# State initialization
# State initialization
...
@@ -151,35 +143,13 @@ class FusedAdam(torch.optim.Optimizer):
...
@@ -151,35 +143,13 @@ class FusedAdam(torch.optim.Optimizer):
if
(
len
(
g_16
)
>
0
):
if
(
len
(
g_16
)
>
0
):
state
[
'step'
]
+=
1
state
[
'step'
]
+=
1
multi_tensor_applier
(
self
.
multi_tensor_adam
,
multi_tensor_applier
(
self
.
multi_tensor_adam
,
self
.
_dummy_overflow_buf
,
[
g_16
,
p_16
,
m_16
,
v_16
],
self
.
_dummy_overflow_buf
,
group
[
'lr'
],
beta1
,
beta2
,
group
[
'eps'
],
state
[
'step'
],
self
.
adam_w_mode
,
[
g_16
,
bias_correction
,
group
[
'weight_decay'
])
p_16
,
m_16
,
v_16
],
group
[
'lr'
],
beta1
,
beta2
,
group
[
'eps'
],
state
[
'step'
],
self
.
adam_w_mode
,
bias_correction
,
group
[
'weight_decay'
])
if
(
len
(
g_32
)
>
0
):
if
(
len
(
g_32
)
>
0
):
state
[
'step'
]
+=
1
state
[
'step'
]
+=
1
multi_tensor_applier
(
self
.
multi_tensor_adam
,
multi_tensor_applier
(
self
.
multi_tensor_adam
,
self
.
_dummy_overflow_buf
,
[
g_32
,
p_32
,
m_32
,
v_32
],
self
.
_dummy_overflow_buf
,
group
[
'lr'
],
beta1
,
beta2
,
group
[
'eps'
],
state
[
'step'
],
self
.
adam_w_mode
,
[
g_32
,
bias_correction
,
group
[
'weight_decay'
])
p_32
,
m_32
,
v_32
],
group
[
'lr'
],
beta1
,
beta2
,
group
[
'eps'
],
state
[
'step'
],
self
.
adam_w_mode
,
bias_correction
,
group
[
'weight_decay'
])
return
loss
return
loss
deepspeed/ops/adam/multi_tensor_apply.py
View file @
5bcc463d
'''
# Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Copyright NVIDIA/apex
Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex, commit a109f85
This file is adapted from NVIDIA/apex, commit a109f85
'''
"""
class
MultiTensorApply
(
object
):
class
MultiTensorApply
(
object
):
def
__init__
(
self
,
chunk_size
):
def
__init__
(
self
,
chunk_size
):
self
.
chunk_size
=
chunk_size
self
.
chunk_size
=
chunk_size
...
...
deepspeed/ops/aio/__init__.py
View file @
5bcc463d
'''
# Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0
Licensed under the MIT license.
'''
# DeepSpeed Team
from
..op_builder
import
AsyncIOBuilder
from
..op_builder
import
AsyncIOBuilder
deepspeed/ops/lamb/__init__.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from
.fused_lamb
import
FusedLamb
from
.fused_lamb
import
FusedLamb
deepspeed/ops/lamb/fused_lamb.py
View file @
5bcc463d
'''
# Copyright (c) Microsoft Corporation.
Copyright 2019 The Microsoft DeepSpeed Team
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Copyright NVIDIA/apex
Copyright NVIDIA/apex
This file is adapted from NVIDIA/apex/optimizer/fused_adam and implements the LAMB optimizer
This file is adapted from NVIDIA/apex/optimizer/fused_adam and implements the LAMB optimizer
'''
"""
import
types
import
types
import
torch
import
torch
from
deepspeed.ops.op_builder
import
FusedLambBuilder
from
deepspeed.ops.op_builder
import
FusedLambBuilder
...
@@ -35,12 +37,12 @@ class FusedLamb(torch.optim.Optimizer):
...
@@ -35,12 +37,12 @@ class FusedLamb(torch.optim.Optimizer):
min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01)
min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01)
amsgrad (boolean, optional): NOT SUPPORTED in FusedLamb!
amsgrad (boolean, optional): NOT SUPPORTED in FusedLamb!
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
params
,
params
,
lr
=
1e-3
,
lr
=
1e-3
,
bias_correction
=
True
,
bias_correction
=
True
,
betas
=
(
0.9
,
betas
=
(
0.9
,
0.999
),
0.999
),
eps
=
1e-8
,
eps
=
1e-8
,
eps_inside_sqrt
=
False
,
eps_inside_sqrt
=
False
,
weight_decay
=
0.
,
weight_decay
=
0.
,
...
@@ -64,12 +66,7 @@ class FusedLamb(torch.optim.Optimizer):
...
@@ -64,12 +66,7 @@ class FusedLamb(torch.optim.Optimizer):
self
.
eps_mode
=
0
if
eps_inside_sqrt
else
1
self
.
eps_mode
=
0
if
eps_inside_sqrt
else
1
self
.
lamb_coeffs
=
[]
self
.
lamb_coeffs
=
[]
def
step
(
self
,
def
step
(
self
,
closure
=
None
,
grads
=
None
,
output_params
=
None
,
scale
=
1.
,
grad_norms
=
None
):
closure
=
None
,
grads
=
None
,
output_params
=
None
,
scale
=
1.
,
grad_norms
=
None
):
"""Performs a single optimization step.
"""Performs a single optimization step.
Arguments:
Arguments:
...
@@ -114,7 +111,8 @@ class FusedLamb(torch.optim.Optimizer):
...
@@ -114,7 +111,8 @@ class FusedLamb(torch.optim.Optimizer):
#remove the previous coeffs
#remove the previous coeffs
del
self
.
lamb_coeffs
[:]
del
self
.
lamb_coeffs
[:]
for
group
,
grads_this_group
,
output_params_this_group
,
grad_norm_group
in
zip
(
self
.
param_groups
,
grads_group
,
output_params_group
,
grad_norms
):
for
group
,
grads_this_group
,
output_params_this_group
,
grad_norm_group
in
zip
(
self
.
param_groups
,
grads_group
,
output_params_group
,
grad_norms
):
if
grads_this_group
is
None
:
if
grads_this_group
is
None
:
grads_this_group
=
[
None
]
*
len
(
group
[
'params'
])
grads_this_group
=
[
None
]
*
len
(
group
[
'params'
])
if
output_params_this_group
is
None
:
if
output_params_this_group
is
None
:
...
@@ -127,7 +125,8 @@ class FusedLamb(torch.optim.Optimizer):
...
@@ -127,7 +125,8 @@ class FusedLamb(torch.optim.Optimizer):
bias_correction
=
1
if
group
[
'bias_correction'
]
else
0
bias_correction
=
1
if
group
[
'bias_correction'
]
else
0
for
p
,
grad
,
output_param
,
grad_norm
in
zip
(
group
[
'params'
],
grads_this_group
,
output_params_this_group
,
grad_norm_group
):
for
p
,
grad
,
output_param
,
grad_norm
in
zip
(
group
[
'params'
],
grads_this_group
,
output_params_this_group
,
grad_norm_group
):
# compute combined scale factor for this group
# compute combined scale factor for this group
combined_scale
=
scale
combined_scale
=
scale
...
@@ -162,24 +161,10 @@ class FusedLamb(torch.optim.Optimizer):
...
@@ -162,24 +161,10 @@ class FusedLamb(torch.optim.Optimizer):
state
[
'step'
]
+=
1
state
[
'step'
]
+=
1
out_p
=
torch
.
tensor
(
out_p
=
torch
.
tensor
([],
dtype
=
torch
.
float
)
if
output_param
is
None
else
output_param
[],
lamb_coeff
=
self
.
fused_lamb_cuda
.
lamb
(
p
.
data
,
out_p
,
exp_avg
,
exp_avg_sq
,
grad
,
group
[
'lr'
],
beta1
,
dtype
=
torch
.
float
)
if
output_param
is
None
else
output_param
beta2
,
max_coeff
,
min_coeff
,
group
[
'eps'
],
combined_scale
,
lamb_coeff
=
self
.
fused_lamb_cuda
.
lamb
(
p
.
data
,
state
[
'step'
],
self
.
eps_mode
,
bias_correction
,
out_p
,
exp_avg
,
exp_avg_sq
,
grad
,
group
[
'lr'
],
beta1
,
beta2
,
max_coeff
,
min_coeff
,
group
[
'eps'
],
combined_scale
,
state
[
'step'
],
self
.
eps_mode
,
bias_correction
,
group
[
'weight_decay'
])
group
[
'weight_decay'
])
self
.
lamb_coeffs
.
append
(
lamb_coeff
)
self
.
lamb_coeffs
.
append
(
lamb_coeff
)
return
loss
return
loss
...
...
deepspeed/ops/module_inject.py
deleted
100755 → 0
View file @
ac5fbab4
import
copy
import
torch
import
deepspeed
from
deepspeed.ops
import
DeepSpeedTransformerConfig
def
_copy_child_transformer_state
(
new_module
,
orig_child
,
pre_layer_norm
):
# copy relevant state from original child -> new module
qw
=
orig_child
.
attention
.
self
.
query
.
weight
qb
=
orig_child
.
attention
.
self
.
query
.
bias
kw
=
orig_child
.
attention
.
self
.
key
.
weight
kb
=
orig_child
.
attention
.
self
.
key
.
bias
vw
=
orig_child
.
attention
.
self
.
value
.
weight
vb
=
orig_child
.
attention
.
self
.
value
.
bias
qkvw
=
torch
.
cat
((
qw
,
kw
,
vw
),
0
)
qkvb
=
torch
.
cat
((
qb
,
kb
,
vb
),
0
)
#qw.data,kw.data,vw.data = torch.chunk(qkvw, 3, axis=0)
#qb.data,kb.data,vb.data = torch.chunk(qkvb, 3, axis=0)
new_module
.
attn_qkvw
.
data
=
qkvw
new_module
.
attn_qkvb
.
data
=
qkvb
new_module
.
attn_ow
.
data
=
orig_child
.
attention
.
output
.
dense
.
weight
new_module
.
attn_ob
.
data
=
orig_child
.
attention
.
output
.
dense
.
bias
if
pre_layer_norm
:
attention_layernorm
=
orig_child
.
PostAttentionLayerNorm
else
:
attention_layernorm
=
orig_child
.
attention
.
output
.
LayerNorm
new_module
.
attn_nw
.
data
=
attention_layernorm
.
weight
new_module
.
attn_nb
.
data
=
attention_layernorm
.
bias
if
pre_layer_norm
:
intermediate_ff
=
orig_child
.
intermediate
.
dense_act
else
:
intermediate_ff
=
orig_child
.
intermediate
.
dense
new_module
.
inter_w
.
data
=
intermediate_ff
.
weight
new_module
.
inter_b
.
data
=
intermediate_ff
.
bias
new_module
.
output_w
.
data
=
orig_child
.
output
.
dense
.
weight
new_module
.
output_b
.
data
=
orig_child
.
output
.
dense
.
bias
if
pre_layer_norm
:
transformer_layernorm
=
orig_child
.
PreAttentionLayerNorm
else
:
transformer_layernorm
=
orig_child
.
output
.
LayerNorm
new_module
.
norm_w
.
data
=
transformer_layernorm
.
weight
new_module
.
norm_b
.
data
=
transformer_layernorm
.
bias
def
_replace_transformer_layer
(
orig_layer_impl
,
model
,
transformer_config
):
""" Replace bert-style transformer layers with DeepSpeed's transformer layer
Arguments:
orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
e.g., transformers.modeling_bert.BertLayer.
model (torch.nn.Module): user's nn.module representing their model
transformer_config (dict): deepspeed transformer layer config containing hidden size, attention heads, etc.
Returns:
Updated nn.module with replaced transformer layers
"""
def
replace_fn
(
child
):
new_module
=
deepspeed
.
DeepSpeedTransformerLayer
(
transformer_config
)
_copy_child_transformer_state
(
new_module
,
child
,
transformer_config
.
pre_layer_norm
)
return
new_module
return
_replace_module
(
model
=
model
,
orig_class
=
orig_layer_impl
,
replace_fn
=
replace_fn
)
def
replace_module
(
orig_module_impl
,
model
,
replacement_module_config
):
""" Replace client module
Arguments:
orig_module_impl (torch.nn.Module): original module implementation to replace,
e.g., transformers.modeling_bert.BertLayer.
model (torch.nn.Module): user's nn.module representing their model
replacement_module_config (dict): deepspeed replacement module config (e.g., DeepSpeedTransformerConfig) .
Returns:
Updated nn.module with replaced modules
"""
assert
isinstance
(
replacement_module_config
,
DeepSpeedTransformerConfig
),
\
'Only DeepSpeedTransformerConfig is currently supported as replacement config'
return
_replace_transformer_layer
(
orig_layer_impl
=
orig_module_impl
,
model
=
model
,
transformer_config
=
replacement_module_config
)
def _revert_transformer_layer(orig_layer_impl, model, bert_config, transformer_config):
    """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        bert_config (dict): model config containing hidden size, attention heads, etc.
        transformer_config (dict): deepspeed transformer config used for replacement
    Returns:
        Updated nn.module with original bert-style transformer layers
    """

    def replace_fn(child):
        """Build a fresh original-style layer and copy ``child``'s fused state into it."""
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(bert_config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data

        # DeepSpeed stores Q/K/V fused along dim 0; split them back apart.
        # NOTE: ``dim`` is the documented torch.chunk keyword (was ``axis``).
        qw, kw, vw = torch.chunk(qkvw, 3, dim=0)
        qb, kb, vb = torch.chunk(qkvb, 3, dim=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        # Attention layer-norm: pre-LN models keep it in a different slot than post-LN.
        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if transformer_config.pre_layer_norm:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        # Intermediate (feed-forward) weights; attribute name differs by LN placement.
        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if transformer_config.pre_layer_norm:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        # Final layer-norm of the block.
        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if transformer_config.pre_layer_norm:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b

        return orig_module

    return _replace_module(model=model,
                           orig_class=deepspeed.DeepSpeedTransformerLayer,
                           replace_fn=replace_fn)
def revert_module(orig_module_impl, model, orig_module_config, replacement_module_config):
    """ Revert DeepSpeed's module back to original client module
    Arguments:
        orig_module_impl (torch.nn.Module): the original module that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        orig_module_config (dict): original module configuration
        replacement_module_config (dict): replacement deepspeed module configuration
    Returns:
        Updated nn.module with original bert-style transformer layers
    """
    # Reverting is only implemented for the transformer-layer replacement path.
    assert isinstance(
        replacement_module_config, DeepSpeedTransformerConfig
    ), 'Only DeepSpeedTransformerConfig is currently supported as replacement config'
    return _revert_transformer_layer(orig_layer_impl=orig_module_impl,
                                     model=model,
                                     bert_config=orig_module_config,
                                     transformer_config=replacement_module_config)
def _replace_module(model, orig_class, replace_fn):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.
    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
            desired type and return a new instance.
    Returns:
        A modified ``model``.
    """
    # A single-class replacement is just a one-entry policy table.
    return _replace_module_using_policies(model, {orig_class: replace_fn})
def
_replace_module_using_policies
(
model
,
policies
):
""" Traverse model's children recursively and apply any transformations in ``policies``.
Arguments:
model (torch.nn.Module): model to augment
policies (dict): Mapping of source class to replacement function.
Returns:
Modified ``model``.
"""
for
name
,
child
in
model
.
named_children
():
if
child
.
__class__
in
policies
:
orig
=
repr
(
child
)
setattr
(
model
,
name
,
policies
[
child
.
__class__
](
child
))
new
=
getattr
(
model
,
name
)
else
:
_replace_module_using_policies
(
child
,
policies
)
return
model
deepspeed/ops/quantizer/__init__.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from
.quantizer
import
ds_quantizer
from
.quantizer
import
ds_quantizer
deepspeed/ops/quantizer/quantizer.py
View file @
5bcc463d
'''
# Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team
# SPDX-License-Identifier: Apache-2.0
'''
# DeepSpeed Team
import
torch
import
torch
from
deepspeed.ops.op_builder
import
QuantizerBuilder
from
deepspeed.ops.op_builder
import
QuantizerBuilder
...
...
deepspeed/ops/random_ltd/__init__.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from
.dropping_utils
import
gpt_sample_tokens
,
bert_sample_tokens
,
GatherTokens
,
ScatterTokens
from
.dropping_utils
import
gpt_sample_tokens
,
bert_sample_tokens
,
GatherTokens
,
ScatterTokens
deepspeed/ops/random_ltd/dropping_utils.py
View file @
5bcc463d
"""
# Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team
# SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
import
torch
import
torch
from
deepspeed.ops.op_builder
import
RandomLTDBuilder
from
deepspeed.ops.op_builder
import
RandomLTDBuilder
...
@@ -23,9 +25,7 @@ def gpt_sample_tokens(reserved_length: int,
...
@@ -23,9 +25,7 @@ def gpt_sample_tokens(reserved_length: int,
prob_dist
=
torch
.
ones
((
layers
*
batch_size
,
seq_length
),
device
=
device
)
prob_dist
=
torch
.
ones
((
layers
*
batch_size
,
seq_length
),
device
=
device
)
sampled_indices
=
torch
.
multinomial
(
prob_dist
,
reserved_length
)
sampled_indices
=
torch
.
multinomial
(
prob_dist
,
reserved_length
)
sampled_indices
=
sampled_indices
.
reshape
(
layers
,
sampled_indices
=
sampled_indices
.
reshape
(
layers
,
batch_size
,
reserved_length
).
to
(
torch
.
int32
)
batch_size
,
reserved_length
).
to
(
torch
.
int32
)
global
random_ltd_module
global
random_ltd_module
if
random_ltd_module
is
None
:
if
random_ltd_module
is
None
:
random_ltd_module
=
RandomLTDBuilder
().
load
()
random_ltd_module
=
RandomLTDBuilder
().
load
()
...
@@ -59,9 +59,7 @@ def bert_sample_tokens(reserved_length: int,
...
@@ -59,9 +59,7 @@ def bert_sample_tokens(reserved_length: int,
prob_dist
=
torch
.
ones
((
layers
*
batch_size
,
seq_length
),
device
=
device
)
prob_dist
=
torch
.
ones
((
layers
*
batch_size
,
seq_length
),
device
=
device
)
sampled_indices
=
torch
.
multinomial
(
prob_dist
,
reserved_length
)
sampled_indices
=
torch
.
multinomial
(
prob_dist
,
reserved_length
)
sampled_indices
=
sampled_indices
.
reshape
(
layers
,
sampled_indices
=
sampled_indices
.
reshape
(
layers
,
batch_size
,
reserved_length
).
to
(
torch
.
int32
)
batch_size
,
reserved_length
).
to
(
torch
.
int32
)
global
random_ltd_module
global
random_ltd_module
if
random_ltd_module
is
None
:
if
random_ltd_module
is
None
:
random_ltd_module
=
RandomLTDBuilder
().
load
()
random_ltd_module
=
RandomLTDBuilder
().
load
()
...
@@ -82,11 +80,9 @@ def bert_sample_tokens(reserved_length: int,
...
@@ -82,11 +80,9 @@ def bert_sample_tokens(reserved_length: int,
class
GatherTokens
(
torch
.
autograd
.
Function
):
class
GatherTokens
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
def
forward
(
ctx
,
activations
:
torch
.
Tensor
,
sorted_indices
:
torch
.
Tensor
,
batch_first
:
bool
):
activations
:
torch
.
Tensor
,
sorted_indices
:
torch
.
Tensor
,
batch_first
:
bool
):
global
random_ltd_module
global
random_ltd_module
if
random_ltd_module
is
None
:
if
random_ltd_module
is
None
:
random_ltd_module
=
RandomLTDBuilder
().
load
()
random_ltd_module
=
RandomLTDBuilder
().
load
()
...
@@ -104,25 +100,18 @@ class GatherTokens(torch.autograd.Function):
...
@@ -104,25 +100,18 @@ class GatherTokens(torch.autograd.Function):
activations
,
sorted_indices
=
ctx
.
saved_tensors
activations
,
sorted_indices
=
ctx
.
saved_tensors
batch_first
=
ctx
.
batch_first
batch_first
=
ctx
.
batch_first
return
random_ltd_module
.
token_scatter_
(
a_gradients
,
return
random_ltd_module
.
token_scatter_
(
a_gradients
,
g_gradients
,
sorted_indices
,
batch_first
),
None
,
None
g_gradients
,
sorted_indices
,
batch_first
),
None
,
None
class
ScatterTokens
(
torch
.
autograd
.
Function
):
class
ScatterTokens
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
def
forward
(
ctx
,
all_activations
:
torch
.
Tensor
,
layer_activations
:
torch
.
Tensor
,
sorted_indices
:
torch
.
Tensor
,
all_activations
:
torch
.
Tensor
,
layer_activations
:
torch
.
Tensor
,
sorted_indices
:
torch
.
Tensor
,
batch_first
:
bool
):
batch_first
:
bool
):
global
random_ltd_module
global
random_ltd_module
if
random_ltd_module
is
None
:
if
random_ltd_module
is
None
:
random_ltd_module
=
RandomLTDBuilder
().
load
()
random_ltd_module
=
RandomLTDBuilder
().
load
()
scatter_results
=
random_ltd_module
.
token_scatter_
(
all_activations
.
clone
(),
scatter_results
=
random_ltd_module
.
token_scatter_
(
all_activations
.
clone
(),
layer_activations
,
sorted_indices
,
layer_activations
,
sorted_indices
,
batch_first
)
batch_first
)
ctx
.
save_for_backward
(
sorted_indices
)
ctx
.
save_for_backward
(
sorted_indices
)
...
@@ -139,7 +128,5 @@ class ScatterTokens(torch.autograd.Function):
...
@@ -139,7 +128,5 @@ class ScatterTokens(torch.autograd.Function):
sorted_indices
,
=
ctx
.
saved_tensors
sorted_indices
,
=
ctx
.
saved_tensors
batch_first
=
ctx
.
batch_first
batch_first
=
ctx
.
batch_first
ret_val
=
random_ltd_module
.
token_gather
(
out_gradients
,
ret_val
=
random_ltd_module
.
token_gather
(
out_gradients
,
sorted_indices
,
batch_first
)
sorted_indices
,
batch_first
)
return
out_gradients
,
ret_val
,
None
,
None
return
out_gradients
,
ret_val
,
None
,
None
deepspeed/ops/sparse_attention/__init__.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from
.sparsity_config
import
SparsityConfig
,
DenseSparsityConfig
,
FixedSparsityConfig
,
VariableSparsityConfig
,
BigBirdSparsityConfig
,
BSLongformerSparsityConfig
,
LocalSlidingWindowSparsityConfig
from
.sparsity_config
import
SparsityConfig
,
DenseSparsityConfig
,
FixedSparsityConfig
,
VariableSparsityConfig
,
BigBirdSparsityConfig
,
BSLongformerSparsityConfig
,
LocalSlidingWindowSparsityConfig
from
.sparse_self_attention
import
SparseSelfAttention
from
.sparse_self_attention
import
SparseSelfAttention
...
...
deepspeed/ops/sparse_attention/bert_sparse_self_attention.py
View file @
5bcc463d
"""
# Copyright (c) Microsoft Corporation.
Copyright 2020 The Microsoft DeepSpeed Team
# SPDX-License-Identifier: Apache-2.0
"""
# DeepSpeed Team
from
torch
import
nn
from
torch
import
nn
from
deepspeed.ops.sparse_attention
import
SparseSelfAttention
,
FixedSparsityConfig
from
deepspeed.ops.sparse_attention
import
SparseSelfAttention
,
FixedSparsityConfig
...
@@ -13,6 +14,7 @@ class BertSparseSelfAttention(nn.Module):
...
@@ -13,6 +14,7 @@ class BertSparseSelfAttention(nn.Module):
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial.
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
config
,
config
,
...
@@ -29,10 +31,8 @@ class BertSparseSelfAttention(nn.Module):
...
@@ -29,10 +31,8 @@ class BertSparseSelfAttention(nn.Module):
super
(
BertSparseSelfAttention
,
self
).
__init__
()
super
(
BertSparseSelfAttention
,
self
).
__init__
()
if
config
.
hidden_size
%
config
.
num_attention_heads
!=
0
:
if
config
.
hidden_size
%
config
.
num_attention_heads
!=
0
:
raise
ValueError
(
raise
ValueError
(
"The hidden size (%d) is not a multiple of the number of attention "
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)"
%
(
config
.
hidden_size
,
config
.
num_attention_heads
))
"heads (%d)"
%
(
config
.
hidden_size
,
config
.
num_attention_heads
))
self
.
num_attention_heads
=
config
.
num_attention_heads
self
.
num_attention_heads
=
config
.
num_attention_heads
self
.
attention_head_size
=
int
(
config
.
hidden_size
/
config
.
num_attention_heads
)
self
.
attention_head_size
=
int
(
config
.
hidden_size
/
config
.
num_attention_heads
)
self
.
all_head_size
=
self
.
num_attention_heads
*
self
.
attention_head_size
self
.
all_head_size
=
self
.
num_attention_heads
*
self
.
attention_head_size
...
@@ -44,8 +44,7 @@ class BertSparseSelfAttention(nn.Module):
...
@@ -44,8 +44,7 @@ class BertSparseSelfAttention(nn.Module):
self
.
sparse_self_attention
=
SparseSelfAttention
(
sparsity_config
)
self
.
sparse_self_attention
=
SparseSelfAttention
(
sparsity_config
)
def
transpose_for_scores
(
self
,
x
):
def
transpose_for_scores
(
self
,
x
):
new_x_shape
=
x
.
size
()[:
-
1
]
+
(
self
.
num_attention_heads
,
new_x_shape
=
x
.
size
()[:
-
1
]
+
(
self
.
num_attention_heads
,
self
.
attention_head_size
)
self
.
attention_head_size
)
x
=
x
.
view
(
*
new_x_shape
)
x
=
x
.
view
(
*
new_x_shape
)
return
x
.
permute
(
0
,
2
,
1
,
3
)
return
x
.
permute
(
0
,
2
,
1
,
3
)
...
...
deepspeed/ops/sparse_attention/matmul.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
# DeepSpeed note, code taken & adapted from commit 9aa94789f13ada713af36cfd8cca2fc9a7f6b79a
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
# https://github.com/ptillet/torch-blocksparse/blob/master/torch_blocksparse/matmul.py
...
@@ -12,29 +15,8 @@ from deepspeed.accelerator import get_accelerator
...
@@ -12,29 +15,8 @@ from deepspeed.accelerator import get_accelerator
@
triton
.
jit
@
triton
.
jit
def
_kernel
(
A
,
def
_kernel
(
A
,
B
,
C
,
stride_za
,
stride_ha
,
stride_ma
,
stride_ka
,
stride_zb
,
stride_hb
,
stride_kb
,
stride_nb
,
stride_zc
,
B
,
stride_hc
,
stride_mc
,
stride_nc
,
DS0
,
DS1
,
SDD_K
,
SDD_off_width
,
lut
,
locks
,
nlocks
,
**
meta
):
C
,
stride_za
,
stride_ha
,
stride_ma
,
stride_ka
,
stride_zb
,
stride_hb
,
stride_kb
,
stride_nb
,
stride_zc
,
stride_hc
,
stride_mc
,
stride_nc
,
DS0
,
DS1
,
SDD_K
,
SDD_off_width
,
lut
,
locks
,
nlocks
,
**
meta
):
TM
=
meta
[
'TM'
]
TM
=
meta
[
'TM'
]
TN
=
meta
[
'TN'
]
TN
=
meta
[
'TN'
]
TK
=
meta
[
'TK'
]
TK
=
meta
[
'TK'
]
...
@@ -194,8 +176,7 @@ def _kernel(A,
...
@@ -194,8 +176,7 @@ def _kernel(A,
tl
.
store
(
pc
,
c
,
mask
=
checkc
)
tl
.
store
(
pc
,
c
,
mask
=
checkc
)
# accumulate partial results using spin-locks
# accumulate partial results using spin-locks
else
:
else
:
plock
=
locks
+
tl
.
program_id
(
2
)
*
nlocks
*
tl
.
num_programs
(
1
)
+
tl
.
program_id
(
plock
=
locks
+
tl
.
program_id
(
2
)
*
nlocks
*
tl
.
num_programs
(
1
)
+
tl
.
program_id
(
1
)
*
nlocks
+
lockid
-
1
1
)
*
nlocks
+
lockid
-
1
pcount
=
plock
+
tl
.
num_programs
(
2
)
*
tl
.
num_programs
(
1
)
*
nlocks
pcount
=
plock
+
tl
.
num_programs
(
2
)
*
tl
.
num_programs
(
1
)
*
nlocks
while
tl
.
atomic_cas
(
plock
,
0
,
1
)
==
1
:
while
tl
.
atomic_cas
(
plock
,
0
,
1
)
==
1
:
pass
pass
...
@@ -292,10 +273,7 @@ class _sparse_matmul(torch.autograd.Function):
...
@@ -292,10 +273,7 @@ class _sparse_matmul(torch.autograd.Function):
#segmented = _sparse_matmul.sdd_segment(layout.type(torch.int32), start_width)
#segmented = _sparse_matmul.sdd_segment(layout.type(torch.int32), start_width)
start_width
=
(
128
if
block
>
16
else
32
)
//
block
start_width
=
(
128
if
block
>
16
else
32
)
//
block
layout
=
layout
.
type
(
torch
.
int32
)
layout
=
layout
.
type
(
torch
.
int32
)
segmented
=
libtriton
.
superblock
(
layout
.
data_ptr
(),
segmented
=
libtriton
.
superblock
(
layout
.
data_ptr
(),
layout
.
shape
[
0
],
layout
.
shape
[
1
],
layout
.
shape
[
2
],
layout
.
shape
[
0
],
layout
.
shape
[
1
],
layout
.
shape
[
2
],
start_width
)
start_width
)
luts
,
widths
,
packs
=
[],
[],
[]
luts
,
widths
,
packs
=
[],
[],
[]
for
size
,
nnz
in
segmented
:
for
size
,
nnz
in
segmented
:
...
@@ -317,19 +295,7 @@ class _sparse_matmul(torch.autograd.Function):
...
@@ -317,19 +295,7 @@ class _sparse_matmul(torch.autograd.Function):
return
luts
,
None
,
widths
,
packs
return
luts
,
None
,
widths
,
packs
@
staticmethod
@
staticmethod
def
_sdd_matmul
(
a
,
def
_sdd_matmul
(
a
,
b
,
trans_a
,
trans_b
,
trans_c
,
spdims
,
block
,
luts
,
num_locks
,
widths
,
packs
,
bench
,
time
):
b
,
trans_a
,
trans_b
,
trans_c
,
spdims
,
block
,
luts
,
num_locks
,
widths
,
packs
,
bench
,
time
):
if
trans_c
:
if
trans_c
:
a
,
b
=
b
,
a
a
,
b
=
b
,
a
trans_a
,
trans_b
=
not
trans_b
,
not
trans_a
trans_a
,
trans_b
=
not
trans_b
,
not
trans_a
...
@@ -339,9 +305,8 @@ class _sparse_matmul(torch.autograd.Function):
...
@@ -339,9 +305,8 @@ class _sparse_matmul(torch.autograd.Function):
b_dim
=
-
1
if
trans_b
else
-
2
b_dim
=
-
1
if
trans_b
else
-
2
a_inner
,
b_inner
=
a
.
shape
[
a_dim
],
b
.
shape
[
b_dim
]
a_inner
,
b_inner
=
a
.
shape
[
a_dim
],
b
.
shape
[
b_dim
]
if
a_inner
!=
b_inner
:
if
a_inner
!=
b_inner
:
raise
ValueError
(
raise
ValueError
(
f
"Size of tensor A along the
{
a_dim
}
dim (
{
a_inner
}
) must match size "
f
"Size of tensor A along the
{
a_dim
}
dim (
{
a_inner
}
) must match size "
f
"of tensor B along the
{
b_dim
}
dim (
{
b_inner
}
)"
)
f
"of tensor B along the
{
b_dim
}
dim (
{
b_inner
}
)"
)
if
a_inner
%
16
!=
0
:
if
a_inner
%
16
!=
0
:
raise
ValueError
(
'Reduction size for SDD must be a multiple of 16'
)
raise
ValueError
(
'Reduction size for SDD must be a multiple of 16'
)
...
@@ -356,12 +321,7 @@ class _sparse_matmul(torch.autograd.Function):
...
@@ -356,12 +321,7 @@ class _sparse_matmul(torch.autograd.Function):
device
=
a
.
device
device
=
a
.
device
# create kernel
# create kernel
total_width
=
sum
([
width
*
pack
*
pack
for
width
,
pack
in
zip
(
widths
,
packs
)])
total_width
=
sum
([
width
*
pack
*
pack
for
width
,
pack
in
zip
(
widths
,
packs
)])
c
=
torch
.
empty
((
batch_size
,
c
=
torch
.
empty
((
batch_size
,
total_width
,
block
,
block
),
dtype
=
dtype
,
device
=
a
.
device
)
total_width
,
block
,
block
),
dtype
=
dtype
,
device
=
a
.
device
)
for
lut
,
width
,
pack
in
zip
(
luts
,
widths
,
packs
):
for
lut
,
width
,
pack
in
zip
(
luts
,
widths
,
packs
):
F32TK
=
[
8
,
16
]
F32TK
=
[
8
,
16
]
F16TK
=
[
16
]
F16TK
=
[
16
]
...
@@ -387,12 +347,7 @@ class _sparse_matmul(torch.autograd.Function):
...
@@ -387,12 +347,7 @@ class _sparse_matmul(torch.autograd.Function):
max_width
=
49152
max_width
=
49152
total
=
0
if
bench
else
None
total
=
0
if
bench
else
None
for
off_width
in
range
(
0
,
width
,
max_width
):
for
off_width
in
range
(
0
,
width
,
max_width
):
grid
=
lambda
meta
:
[
grid
=
lambda
meta
:
[
meta
[
'TZ'
],
min
(
max_width
,
width
-
off_width
),
batch_size
]
meta
[
'TZ'
],
min
(
max_width
,
width
-
off_width
),
batch_size
]
_kernel
[
grid
](
a
,
_kernel
[
grid
](
a
,
b
,
b
,
c
,
c
,
...
@@ -504,13 +459,7 @@ class _sparse_matmul(torch.autograd.Function):
...
@@ -504,13 +459,7 @@ class _sparse_matmul(torch.autograd.Function):
# create header
# create header
width
=
column
.
size
(
0
)
width
=
column
.
size
(
0
)
offsets
+=
6
*
width
offsets
+=
6
*
width
header
=
torch
.
stack
((
offsets
,
header
=
torch
.
stack
((
offsets
,
segments
,
column
,
depth
,
lockid
,
maxid
),
dim
=
1
).
view
(
-
1
).
contiguous
()
segments
,
column
,
depth
,
lockid
,
maxid
),
dim
=
1
).
view
(
-
1
).
contiguous
()
incs
=
torch
.
stack
((
xincs
,
wincs
),
dim
=
1
).
view
(
-
1
).
contiguous
()
incs
=
torch
.
stack
((
xincs
,
wincs
),
dim
=
1
).
view
(
-
1
).
contiguous
()
incs
=
torch
.
cat
((
incs
,
torch
.
zeros
(
2
,
device
=
incs
.
device
,
dtype
=
incs
.
dtype
)))
incs
=
torch
.
cat
((
incs
,
torch
.
zeros
(
2
,
device
=
incs
.
device
,
dtype
=
incs
.
dtype
)))
# create lut
# create lut
...
@@ -521,19 +470,7 @@ class _sparse_matmul(torch.autograd.Function):
...
@@ -521,19 +470,7 @@ class _sparse_matmul(torch.autograd.Function):
return
lut
,
num_locks
,
width
,
None
return
lut
,
num_locks
,
width
,
None
@
staticmethod
@
staticmethod
def
_dds_matmul
(
a
,
def
_dds_matmul
(
a
,
b
,
trans_a
,
trans_b
,
trans_c
,
spdims
,
block
,
lut
,
num_locks
,
width
,
packs
,
bench
,
time
):
b
,
trans_a
,
trans_b
,
trans_c
,
spdims
,
block
,
lut
,
num_locks
,
width
,
packs
,
bench
,
time
):
global
triton
global
triton
if
triton
is
None
:
if
triton
is
None
:
triton
=
importlib
.
import_module
(
'triton'
)
triton
=
importlib
.
import_module
(
'triton'
)
...
@@ -548,16 +485,7 @@ class _sparse_matmul(torch.autograd.Function):
...
@@ -548,16 +485,7 @@ class _sparse_matmul(torch.autograd.Function):
BS2
=
block
*
spdims
[
1
if
trans_b
else
2
]
BS2
=
block
*
spdims
[
1
if
trans_b
else
2
]
dtype
=
a
.
dtype
dtype
=
a
.
dtype
# kernel
# kernel
meta
=
{
meta
=
{
'TN'
:
block
,
'TM'
:
128
,
'TK'
:
16
,
'BLOCK'
:
block
,
'TZ'
:
1
,
'SDD'
:
False
,
'DSD'
:
False
,
'DDS'
:
True
}
'TN'
:
block
,
'TM'
:
128
,
'TK'
:
16
,
'BLOCK'
:
block
,
'TZ'
:
1
,
'SDD'
:
False
,
'DSD'
:
False
,
'DDS'
:
True
}
# output
# output
CS0
=
AS0
CS0
=
AS0
CS1
=
AS1
CS1
=
AS1
...
@@ -593,19 +521,7 @@ class _sparse_matmul(torch.autograd.Function):
...
@@ -593,19 +521,7 @@ class _sparse_matmul(torch.autograd.Function):
return
c
return
c
@
staticmethod
@
staticmethod
def
_dsd_matmul
(
a
,
def
_dsd_matmul
(
a
,
b
,
trans_a
,
trans_b
,
trans_c
,
spdims
,
block
,
lut
,
num_locks
,
width
,
packs
,
bench
,
time
):
b
,
trans_a
,
trans_b
,
trans_c
,
spdims
,
block
,
lut
,
num_locks
,
width
,
packs
,
bench
,
time
):
global
triton
global
triton
if
triton
is
None
:
if
triton
is
None
:
triton
=
importlib
.
import_module
(
'triton'
)
triton
=
importlib
.
import_module
(
'triton'
)
...
@@ -621,16 +537,7 @@ class _sparse_matmul(torch.autograd.Function):
...
@@ -621,16 +537,7 @@ class _sparse_matmul(torch.autograd.Function):
dtype
=
a
.
dtype
dtype
=
a
.
dtype
# kernel
# kernel
meta
=
{
meta
=
{
'TM'
:
block
,
'TN'
:
128
,
'TK'
:
16
,
'BLOCK'
:
block
,
'TZ'
:
1
,
'SDD'
:
False
,
'DSD'
:
True
,
'DDS'
:
False
}
'TM'
:
block
,
'TN'
:
128
,
'TK'
:
16
,
'BLOCK'
:
block
,
'TZ'
:
1
,
'SDD'
:
False
,
'DSD'
:
True
,
'DDS'
:
False
}
# output
# output
CS0
=
BS0
CS0
=
BS0
CS1
=
BS1
CS1
=
BS1
...
@@ -665,53 +572,14 @@ class _sparse_matmul(torch.autograd.Function):
...
@@ -665,53 +572,14 @@ class _sparse_matmul(torch.autograd.Function):
**
meta
)
**
meta
)
return
c
return
c
fn
=
{
fn
=
{
'sdd'
:
_sdd_matmul
.
__get__
(
object
),
'dsd'
:
_dsd_matmul
.
__get__
(
object
),
'dds'
:
_dds_matmul
.
__get__
(
object
)}
'sdd'
:
_sdd_matmul
.
__get__
(
object
),
'dsd'
:
_dsd_matmul
.
__get__
(
object
),
'dds'
:
_dds_matmul
.
__get__
(
object
)
}
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
def
forward
(
ctx
,
a
,
b
,
trans_a
,
trans_b
,
trans_c
,
mode
,
spdims
,
block
,
c_lut
,
c_num_locks
,
c_width
,
c_packs
,
a
,
c_bench
,
c_time
,
da_lut
,
da_num_locks
,
da_width
,
da_packs
,
da_bench
,
da_time
,
db_lut
,
db_num_locks
,
b
,
db_width
,
db_packs
,
db_bench
,
db_time
):
trans_a
,
c
=
_sparse_matmul
.
fn
[
mode
](
a
,
b
,
trans_a
,
trans_b
,
trans_c
,
spdims
,
block
,
c_lut
,
c_num_locks
,
c_width
,
trans_b
,
c_packs
,
c_bench
,
c_time
)
trans_c
,
mode
,
spdims
,
block
,
c_lut
,
c_num_locks
,
c_width
,
c_packs
,
c_bench
,
c_time
,
da_lut
,
da_num_locks
,
da_width
,
da_packs
,
da_bench
,
da_time
,
db_lut
,
db_num_locks
,
db_width
,
db_packs
,
db_bench
,
db_time
):
c
=
_sparse_matmul
.
fn
[
mode
](
a
,
b
,
trans_a
,
trans_b
,
trans_c
,
spdims
,
block
,
c_lut
,
c_num_locks
,
c_width
,
c_packs
,
c_bench
,
c_time
)
# save for backward
# save for backward
ctx
.
save_for_backward
(
a
,
b
)
ctx
.
save_for_backward
(
a
,
b
)
ctx
.
da_num_locks
=
da_num_locks
ctx
.
da_num_locks
=
da_num_locks
...
@@ -741,34 +609,14 @@ class _sparse_matmul(torch.autograd.Function):
...
@@ -741,34 +609,14 @@ class _sparse_matmul(torch.autograd.Function):
# gradients w.r.t. a
# gradients w.r.t. a
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
needs_input_grad
[
0
]:
mode_da
=
mode
[
1
]
+
mode
[
0
]
+
mode
[
2
]
mode_da
=
mode
[
1
]
+
mode
[
0
]
+
mode
[
2
]
da
=
_sparse_matmul
.
fn
[
mode_da
](
dc
,
da
=
_sparse_matmul
.
fn
[
mode_da
](
dc
,
b
,
False
,
not
ctx
.
trans_b
,
ctx
.
trans_a
,
ctx
.
spdims
,
ctx
.
block
,
b
,
ctx
.
da_lut
,
ctx
.
da_num_locks
,
ctx
.
da_width
,
ctx
.
da_packs
,
ctx
.
da_bench
,
False
,
not
ctx
.
trans_b
,
ctx
.
trans_a
,
ctx
.
spdims
,
ctx
.
block
,
ctx
.
da_lut
,
ctx
.
da_num_locks
,
ctx
.
da_width
,
ctx
.
da_packs
,
ctx
.
da_bench
,
ctx
.
da_time
)
ctx
.
da_time
)
# gradients w.r.t. b
# gradients w.r.t. b
if
ctx
.
needs_input_grad
[
1
]:
if
ctx
.
needs_input_grad
[
1
]:
mode_db
=
mode
[
2
]
+
mode
[
1
]
+
mode
[
0
]
mode_db
=
mode
[
2
]
+
mode
[
1
]
+
mode
[
0
]
db
=
_sparse_matmul
.
fn
[
mode_db
](
a
,
db
=
_sparse_matmul
.
fn
[
mode_db
](
a
,
dc
,
not
ctx
.
trans_a
,
False
,
ctx
.
trans_b
,
ctx
.
spdims
,
ctx
.
block
,
dc
,
ctx
.
db_lut
,
ctx
.
db_num_locks
,
ctx
.
db_width
,
ctx
.
db_packs
,
ctx
.
db_bench
,
not
ctx
.
trans_a
,
False
,
ctx
.
trans_b
,
ctx
.
spdims
,
ctx
.
block
,
ctx
.
db_lut
,
ctx
.
db_num_locks
,
ctx
.
db_width
,
ctx
.
db_packs
,
ctx
.
db_bench
,
ctx
.
db_time
)
ctx
.
db_time
)
return
da
,
db
,
None
,
None
,
None
,
\
return
da
,
db
,
None
,
None
,
None
,
\
None
,
None
,
None
,
None
,
\
None
,
None
,
None
,
None
,
\
...
@@ -785,6 +633,7 @@ class MatMul:
...
@@ -785,6 +633,7 @@ class MatMul:
For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
For more details about sparsity config, please see `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509
"""
"""
def
make_lut
(
self
,
dtype
,
device
):
def
make_lut
(
self
,
dtype
,
device
):
"""Generates the sparsity layout/s used in block-sparse matmul
"""Generates the sparsity layout/s used in block-sparse matmul
"""
"""
...
@@ -797,21 +646,25 @@ class MatMul:
...
@@ -797,21 +646,25 @@ class MatMul:
if
self
.
mode
==
'sdd'
:
if
self
.
mode
==
'sdd'
:
c_lut
,
c_num_locks
,
c_width
,
c_packs
=
_sparse_matmul
.
make_sdd_lut
(
layout
,
block
,
dtype
,
device
)
c_lut
,
c_num_locks
,
c_width
,
c_packs
=
_sparse_matmul
.
make_sdd_lut
(
layout
,
block
,
dtype
,
device
)
elif
self
.
mode
==
'dsd'
:
elif
self
.
mode
==
'dsd'
:
c_lut
,
c_num_locks
,
c_width
,
c_packs
=
_sparse_matmul
.
make_dxx_lut
(
layout
,
block
,
step
,
not
self
.
trans_a
,
device
)
c_lut
,
c_num_locks
,
c_width
,
c_packs
=
_sparse_matmul
.
make_dxx_lut
(
layout
,
block
,
step
,
not
self
.
trans_a
,
device
)
elif
self
.
mode
==
'dds'
:
elif
self
.
mode
==
'dds'
:
c_lut
,
c_num_locks
,
c_width
,
c_packs
=
_sparse_matmul
.
make_dxx_lut
(
layout
,
block
,
step
,
self
.
trans_b
,
device
)
c_lut
,
c_num_locks
,
c_width
,
c_packs
=
_sparse_matmul
.
make_dxx_lut
(
layout
,
block
,
step
,
self
.
trans_b
,
device
)
# DA look-up table
# DA look-up table
if
self
.
mode
==
'sdd'
:
if
self
.
mode
==
'sdd'
:
da_lut
,
da_num_locks
,
da_width
,
da_packs
=
_sparse_matmul
.
make_dxx_lut
(
layout
,
block
,
step
,
True
,
device
)
da_lut
,
da_num_locks
,
da_width
,
da_packs
=
_sparse_matmul
.
make_dxx_lut
(
layout
,
block
,
step
,
True
,
device
)
elif
self
.
mode
==
'dsd'
:
elif
self
.
mode
==
'dsd'
:
da_lut
,
da_num_locks
,
da_width
,
da_packs
=
_sparse_matmul
.
make_sdd_lut
(
layout
,
block
,
dtype
,
device
)
da_lut
,
da_num_locks
,
da_width
,
da_packs
=
_sparse_matmul
.
make_sdd_lut
(
layout
,
block
,
dtype
,
device
)
elif
self
.
mode
==
'dds'
:
elif
self
.
mode
==
'dds'
:
da_lut
,
da_num_locks
,
da_width
,
da_packs
=
_sparse_matmul
.
make_dxx_lut
(
layout
,
block
,
step
,
not
self
.
trans_b
,
device
)
da_lut
,
da_num_locks
,
da_width
,
da_packs
=
_sparse_matmul
.
make_dxx_lut
(
layout
,
block
,
step
,
not
self
.
trans_b
,
device
)
# DB look-up table
# DB look-up table
if
self
.
mode
==
'sdd'
:
if
self
.
mode
==
'sdd'
:
db_lut
,
db_num_locks
,
db_width
,
db_packs
=
_sparse_matmul
.
make_dxx_lut
(
layout
,
block
,
step
,
False
,
device
)
db_lut
,
db_num_locks
,
db_width
,
db_packs
=
_sparse_matmul
.
make_dxx_lut
(
layout
,
block
,
step
,
False
,
device
)
elif
self
.
mode
==
'dsd'
:
elif
self
.
mode
==
'dsd'
:
db_lut
,
db_num_locks
,
db_width
,
db_packs
=
_sparse_matmul
.
make_dxx_lut
(
layout
,
block
,
step
,
self
.
trans_a
,
device
)
db_lut
,
db_num_locks
,
db_width
,
db_packs
=
_sparse_matmul
.
make_dxx_lut
(
layout
,
block
,
step
,
self
.
trans_a
,
device
)
elif
self
.
mode
==
'dds'
:
elif
self
.
mode
==
'dds'
:
db_lut
,
db_num_locks
,
db_width
,
db_packs
=
_sparse_matmul
.
make_sdd_lut
(
layout
,
block
,
dtype
,
device
)
db_lut
,
db_num_locks
,
db_width
,
db_packs
=
_sparse_matmul
.
make_sdd_lut
(
layout
,
block
,
dtype
,
device
)
self
.
lut_cache
[
key
]
=
(
c_lut
,
c_num_locks
,
c_width
,
c_packs
,
\
self
.
lut_cache
[
key
]
=
(
c_lut
,
c_num_locks
,
c_width
,
c_packs
,
\
...
@@ -845,11 +698,10 @@ class MatMul:
...
@@ -845,11 +698,10 @@ class MatMul:
assert
layout_dim
in
(
2
,
3
),
"Layout should be a 2 or 3 dimensional tensor of 0s and 1s"
assert
layout_dim
in
(
2
,
3
),
"Layout should be a 2 or 3 dimensional tensor of 0s and 1s"
if
not
mode
==
'sdd'
:
if
not
mode
==
'sdd'
:
# Dims to be reduced on the 'inside' of the matmul, either -1 or -2
# Dims to be reduced on the 'inside' of the matmul, either -1 or -2
trans_dense
,
trans_sparse
,
sparse_inner
=
(
trans_b
,
trans_a
,
-
1
)
if
mode
==
'dsd'
else
(
trans_a
,
trans_b
,
-
2
)
trans_dense
,
trans_sparse
,
sparse_inner
=
(
trans_b
,
trans_a
,
-
1
)
if
mode
==
'dsd'
else
(
trans_a
,
trans_b
,
self
.
dense_inner_dim
=
-
(
-
2
)
(
sparse_inner
%
2
)
+
1
)
if
not
trans_dense
else
sparse_inner
self
.
dense_inner_dim
=
-
((
sparse_inner
%
2
)
+
1
)
if
not
trans_dense
else
sparse_inner
sparse_inner
=
sparse_inner
if
not
trans_sparse
else
-
(
sparse_inner
=
sparse_inner
if
not
trans_sparse
else
-
((
sparse_inner
%
2
)
+
1
)
(
sparse_inner
%
2
)
+
1
)
# Inner dim of the dense input should be equal to the inner dim of the sparse input
# Inner dim of the dense input should be equal to the inner dim of the sparse input
self
.
dense_inner_size
=
layout
.
shape
[
sparse_inner
]
*
block
self
.
dense_inner_size
=
layout
.
shape
[
sparse_inner
]
*
block
...
@@ -860,8 +712,7 @@ class MatMul:
...
@@ -860,8 +712,7 @@ class MatMul:
if
layout_dim
==
2
:
if
layout_dim
==
2
:
layout
=
layout
.
unsqueeze
(
0
)
layout
=
layout
.
unsqueeze
(
0
)
layout
=
layout
.
long
(
layout
=
layout
.
long
()
# Above code assumes the layout tensor is an integral type
)
# Above code assumes the layout tensor is an integral type
self
.
spdims
=
layout
.
shape
self
.
spdims
=
layout
.
shape
# timings
# timings
...
@@ -909,31 +760,9 @@ class MatMul:
...
@@ -909,31 +760,9 @@ class MatMul:
b
=
MatMul
.
_pad_shape
(
b
,
self
.
mode
==
'dds'
)
b
=
MatMul
.
_pad_shape
(
b
,
self
.
mode
==
'dds'
)
# execute
# execute
c
=
_sparse_matmul
.
apply
(
a
,
c
=
_sparse_matmul
.
apply
(
a
,
b
,
self
.
trans_a
,
self
.
trans_b
,
False
,
self
.
mode
,
self
.
spdims
,
self
.
block
,
c_lut
,
b
,
c_num_locks
,
c_width
,
c_packs
,
self
.
bench
,
time_c
,
da_lut
,
da_num_locks
,
da_width
,
self
.
trans_a
,
da_packs
,
self
.
bench
,
time_da
,
db_lut
,
db_num_locks
,
db_width
,
db_packs
,
self
.
bench
,
self
.
trans_b
,
False
,
self
.
mode
,
self
.
spdims
,
self
.
block
,
c_lut
,
c_num_locks
,
c_width
,
c_packs
,
self
.
bench
,
time_c
,
da_lut
,
da_num_locks
,
da_width
,
da_packs
,
self
.
bench
,
time_da
,
db_lut
,
db_num_locks
,
db_width
,
db_packs
,
self
.
bench
,
time_db
)
time_db
)
# This removes any leading singleton dimensions we may have added to the tensor that weren't in the input
# This removes any leading singleton dimensions we may have added to the tensor that weren't in the input
...
@@ -948,9 +777,8 @@ class MatMul:
...
@@ -948,9 +777,8 @@ class MatMul:
def
_validate_inputs
(
self
,
a
,
b
):
def
_validate_inputs
(
self
,
a
,
b
):
if
a
.
device
!=
b
.
device
:
if
a
.
device
!=
b
.
device
:
raise
ValueError
(
raise
ValueError
(
f
"Inputs must be on the same device; got
{
a
.
device
}
for tensor A "
f
"Inputs must be on the same device; got
{
a
.
device
}
for tensor A "
f
"and
{
b
.
device
}
for tensor B"
)
f
"and
{
b
.
device
}
for tensor B"
)
if
not
get_accelerator
().
on_accelerator
(
a
):
if
not
get_accelerator
().
on_accelerator
(
a
):
raise
ValueError
(
"Only GPU devices are supported for now"
)
raise
ValueError
(
"Only GPU devices are supported for now"
)
...
@@ -958,9 +786,7 @@ class MatMul:
...
@@ -958,9 +786,7 @@ class MatMul:
if
torch
.
is_autocast_enabled
():
if
torch
.
is_autocast_enabled
():
a
,
b
=
a
.
half
(),
b
.
half
()
a
,
b
=
a
.
half
(),
b
.
half
()
elif
a
.
dtype
!=
b
.
dtype
:
elif
a
.
dtype
!=
b
.
dtype
:
raise
ValueError
(
raise
ValueError
(
f
"Inputs must be the same dtype; got
{
a
.
dtype
}
for A and
{
b
.
dtype
}
for B"
)
f
"Inputs must be the same dtype; got
{
a
.
dtype
}
for A and
{
b
.
dtype
}
for B"
)
mode
,
trans_a
,
trans_b
=
self
.
mode
,
self
.
trans_a
,
self
.
trans_b
mode
,
trans_a
,
trans_b
=
self
.
mode
,
self
.
trans_a
,
self
.
trans_b
if
mode
!=
'sdd'
:
if
mode
!=
'sdd'
:
...
@@ -968,14 +794,12 @@ class MatMul:
...
@@ -968,14 +794,12 @@ class MatMul:
dense
,
dense_name
,
sparse
,
sparse_name
=
(
a
,
'A'
,
b
,
'B'
)
if
mode
==
'dds'
else
(
b
,
'B'
,
a
,
'A'
)
dense
,
dense_name
,
sparse
,
sparse_name
=
(
a
,
'A'
,
b
,
'B'
)
if
mode
==
'dds'
else
(
b
,
'B'
,
a
,
'A'
)
dense_inner
=
dense
.
shape
[
self
.
dense_inner_dim
]
dense_inner
=
dense
.
shape
[
self
.
dense_inner_dim
]
if
dense_inner
!=
self
.
dense_inner_size
:
if
dense_inner
!=
self
.
dense_inner_size
:
raise
ValueError
(
raise
ValueError
(
f
"Expected tensor
{
dense_name
}
to have size
{
self
.
dense_inner_size
}
at dim "
f
"Expected tensor
{
dense_name
}
to have size
{
self
.
dense_inner_size
}
at dim "
f
"
{
self
.
dense_inner_dim
%
dense
.
ndim
}
, got
{
dense_inner
}
."
)
f
"
{
self
.
dense_inner_dim
%
dense
.
ndim
}
, got
{
dense_inner
}
."
)
if
sparse
.
shape
[
-
len
(
self
.
sparse_shape
):]
!=
self
.
sparse_shape
:
if
sparse
.
shape
[
-
len
(
self
.
sparse_shape
):]
!=
self
.
sparse_shape
:
raise
ValueError
(
raise
ValueError
(
f
"Expected tensor with trailing dimensions of shape
{
self
.
sparse_shape
}
for argument "
f
"Expected tensor with trailing dimensions of shape
{
self
.
sparse_shape
}
for argument "
f
"
{
sparse_name
}
, got
{
sparse
.
shape
}
"
)
f
"
{
sparse_name
}
, got
{
sparse
.
shape
}
"
)
def
add_extra_dims
(
x
):
def
add_extra_dims
(
x
):
# Add extra leading singleton dimensions if needed
# Add extra leading singleton dimensions if needed
...
@@ -984,8 +808,7 @@ class MatMul:
...
@@ -984,8 +808,7 @@ class MatMul:
singletons
=
[
1
]
*
dims_needed
singletons
=
[
1
]
*
dims_needed
x
=
x
.
view
(
*
singletons
,
*
x
.
shape
)
x
=
x
.
view
(
*
singletons
,
*
x
.
shape
)
elif
dims_needed
<
0
:
elif
dims_needed
<
0
:
raise
ValueError
(
raise
ValueError
(
"Tensors with more than 4 dimensions are not currently supported"
)
"Tensors with more than 4 dimensions are not currently supported"
)
return
x
return
x
...
...
Prev
1
…
11
12
13
14
15
16
17
18
19
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment