OpenDAS / Megatron-LM

Commit 4f2356dc, authored Mar 07, 2022 by Lawrence McAfee

    deleted old distributed_fused_adam

Parent: e5bda3c9
Showing 1 changed file with 0 additions and 96 deletions.

megatron/optimizer/distributed_fused_adam.py  (deleted, 100644 → 0)  +0 -96
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch

from megatron import mpu

# >>>
from lutil import pax, tp
# <<<


class DistributedFusedAdam(torch.optim.Optimizer):
    """Work-in-progress Adam optimizer whose fp32 state is sharded across
    data-parallel ranks."""

    def __init__(self, params):

        super().__init__(params, defaults = {})

        # Shards are allocated lazily, on the first call to step().
        self.initialized = False

        # self.params_32 = None
        # self.grads_32 = None
        # self.opt_m = None
        # self.opt_v = None

        # pax(0, {
        #     "param_groups" : self.param_groups,
        #     "param_groups / 0" : self.param_groups[0],
        #     "param_groups / 1" : self.param_groups[1],
        #     "param_groups / 0 / params" : self.param_groups[0]["params"],
        #     # "param_groups / params" : [ g["params"] for g in self.param_groups ],
        # })

    def initialize(self):

        if self.initialized:
            raise Exception("initialization worked.")
            return

        self.initialized = True

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Partition the flattened parameter space evenly across data-parallel
        # ranks; the last rank may hold a smaller remainder shard.
        total_param_size = sum(
            p.numel()
            for g in self.param_groups
            for p in g["params"])
        shard_size = int(math.ceil(total_param_size / data_parallel_world_size))
        shard_start_index = data_parallel_rank * shard_size
        shard_end_index = min(total_param_size, shard_start_index + shard_size)
        shard_size = shard_end_index - shard_start_index

        allocate_shard = lambda dtype : torch.empty(
            [shard_size],
            dtype = dtype,
            device = torch.cuda.current_device())

        # fp32 master params, gradients, and Adam moments for this rank's shard.
        self.main_param_shard = allocate_shard(torch.float)
        self.main_grad_shard = allocate_shard(torch.float)
        self.adam_m_shard = allocate_shard(torch.float)
        self.adam_v_shard = allocate_shard(torch.float)

        # pax(2, {
        #     "data_parallel_rank" : data_parallel_rank,
        #     "data_parallel_world_size" : data_parallel_world_size,
        #     "total_param_size" : total_param_size,
        #     "shard_size" : shard_size,
        #     "shard" : "%d [ %d, %d ]" % (
        #         shard_size,
        #         shard_start_index,
        #         shard_end_index,
        #     ),
        # })

    def step(self):

        self.initialize()

        raise Exception("what's next?")

# >>>
# eof
# <<<
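
Note: the deleted file's sharding logic splits the flattened parameter space into ceil(total / world_size)-sized chunks, one per data-parallel rank, with the last rank taking whatever remainder is left. Below is a minimal standalone sketch of that arithmetic; the helper name shard_range and the example sizes are illustrative only, not part of the original file.

import math

def shard_range(total_param_size, data_parallel_rank, data_parallel_world_size):
    # Mirrors the ceil-based partitioning in DistributedFusedAdam.initialize():
    # every rank gets ceil(total / world_size) elements, except possibly the
    # last rank, which gets only the remainder.
    shard_size = int(math.ceil(total_param_size / data_parallel_world_size))
    start = data_parallel_rank * shard_size
    end = min(total_param_size, start + shard_size)
    return start, end, end - start

# Example: 10 elements over 4 ranks -> (0, 3, 3), (3, 6, 3), (6, 9, 3), (9, 10, 1).
for rank in range(4):
    print(rank, shard_range(10, rank, 4))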