OpenDAS / Megatron-LM · Commits

Commit e46230dc, authored Mar 09, 2022 by Lawrence McAfee

    moved 'reduce_grads()' to MegatronOptimizer.

Parent: 772a4a2d

Changes: 1 changed file, with 154 additions and 227 deletions

megatron/optimizer/optimizer.py  (+154, -227)
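At a glance, the commit hoists reduce_grads() out of the concrete optimizer classes and into the abstract base class, so every optimizer shares one gradient-reduction path. A minimal sketch of the resulting layout, simplified from the hunks below (method bodies elided; not the full implementation):

    from abc import ABC, abstractmethod

    class MegatronOptimizer(ABC):

        def reduce_grads(self, model):
            # Now concrete in the base class: data-parallel all-reduce plus the
            # embedding-grad syncs shown in the second hunk below.
            ...

        def gather_params(self):
            # Default no-op; subclasses that shard parameters can override it.
            pass

        @abstractmethod
        def step(self):
            ...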
@@ -124,21 +124,6 @@ class MegatronOptimizer(ABC):
         return self.get_loss_scale() * loss
 
-    @abstractmethod
-    def reduce_grads(self):
-        pass
-
-    @abstractmethod
-    def step(self):
-        pass
-
-    @abstractmethod
-    def gather_params(self):
-        pass
-
     @abstractmethod
     def reload_model_params(self):
         """Refreshes any internal state from the current model parameters.
@@ -182,6 +167,80 @@ class MegatronOptimizer(ABC):
     param_groups = property(_get_param_groups, _set_param_groups)
 
+    @abstractmethod
+    def step(self):
+        pass
+
+    def gather_params(self):
+        pass
+
+    def reduce_grads(self, model):
+
+        # >>>
+        from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+        from megatron import get_args
+        from megatron import get_timers
+        from megatron.model import DistributedDataParallel as LocalDDP
+        from megatron.model import Float16Module
+        from megatron.utils import unwrap_model
+
+        args = get_args()
+        timers = get_timers()
+        # <<<
+
+        # All-reduce if needed.
+        if args.DDP_impl == 'local':
+            timers('backward-params-all-reduce').start()
+            for model_module in model:
+                model_module.allreduce_gradients()
+            timers('backward-params-all-reduce').stop()
+
+        # All-reduce word_embeddings' grad across first and last stages to ensure
+        # that word_embeddings parameters stay in sync.
+        # This should only run for models that support pipelined model parallelism
+        # (BERT and GPT-2).
+        timers('backward-embedding-all-reduce').start()
+        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
+                mpu.get_pipeline_model_parallel_world_size() > 1:
+            # >>>
+            # raise Exception("[main] ready for weight sync?")
+            # <<<
+            if mpu.is_pipeline_first_stage(ignore_virtual=True):
+                unwrapped_model = model[0]
+            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
+                unwrapped_model = model[-1]
+            else:  # We do not support the interleaved schedule for T5 yet.
+                unwrapped_model = model[0]
+            unwrapped_model = unwrap_model(
+                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
+
+            if unwrapped_model.share_word_embeddings:
+                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
+                if args.DDP_impl == 'local':
+                    grad = word_embeddings_weight.main_grad
+                else:
+                    grad = word_embeddings_weight.grad
+                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
+
+        # All-reduce position_embeddings grad across first (encoder) and split (decoder)
+        # stages to ensure that position embeddings parameters stay in sync.
+        # This should only run for T5 models with pipeline parallelism
+        if mpu.is_rank_in_position_embedding_group() and \
+                mpu.get_pipeline_model_parallel_world_size() > 1 and \
+                args.pipeline_model_parallel_split_rank is not None:
+            # >>>
+            raise Exception("[main] ready for t5 sync?")
+            # <<<
+            unwrapped_model = model[0]
+            unwrapped_model = unwrap_model(
+                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
+            assert args.DDP_impl == 'local', \
+                'T5 model is only supported with local DDP mode'
+            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
+            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
+        timers('backward-embedding-all-reduce').stop()
+
 
 # class BaseFloat16Optimizer(MegatronOptimizer):
 class MixedPrecisionOptimizer(MegatronOptimizer):
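The call site is not part of this diff, but presumably the training loop now asks the optimizer to reduce gradients instead of doing it itself. A rough, hypothetical sketch of that sequence; the helper name forward_backward_func and the exact signatures are assumptions (other hunks in this file thread an ITERATION argument through step-related helpers, so the real signatures likely differ):

    def train_step(forward_backward_func, model, optimizer):
        # Hypothetical driver, not taken from this commit: run forward/backward
        # over the list of model chunks, then let the optimizer own the reductions.
        optimizer.zero_grad()
        losses = forward_backward_func(model)      # fills p.grad / p.main_grad
        optimizer.reduce_grads(model)              # moved onto the optimizer by this commit
        update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
        optimizer.gather_params()                  # no-op for the fp16 optimizer
        return losses, update_successful, grad_norm, num_zeros_in_grad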
@@ -251,15 +310,9 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
             main_grads, self.found_inf, self.grad_scaler.inv_scale)
 
         # Update across all model parallel instances.
-        if args.use_
-        # >>>
-        # torch.distributed.all_reduce(self.found_inf,
-        #                              op=torch.distributed.ReduceOp.MAX,
-        #                              group=mpu.get_model_parallel_group())
-        # +++
         torch.distributed.all_reduce(self.found_inf,
                                      op=torch.distributed.ReduceOp.MAX,
                                      group=self.get_model_parallel_group())
-        # <<<
 
         # Check for nan.
         found_inf_flag = (self.found_inf.item() > 0)
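The surviving lines above implement a standard mixed-precision guard: each rank notes locally whether unscaling found an inf/NaN gradient, and a MAX all-reduce over the model-parallel group turns that into a global flag so every rank skips the update together. A standalone sketch of the same idea, assuming torch.distributed is already initialized (group handling simplified; this is not the Megatron implementation):

    import torch
    import torch.distributed as dist

    def any_rank_saw_overflow(main_grads, group=None):
        # 1.0 if any local grad contains inf/NaN, else 0.0.
        found_inf = torch.zeros(1, dtype=torch.float, device=main_grads[0].device)
        for g in main_grads:
            if not torch.isfinite(g).all():
                found_inf.fill_(1.0)
                break
        # MAX across ranks: if any rank overflowed, every rank sees 1.0.
        dist.all_reduce(found_inf, op=dist.ReduceOp.MAX, group=group)
        return found_inf.item() > 0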
@@ -267,58 +320,58 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         return found_inf_flag
 
     # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
-    @classmethod
-    def debug_base(cls, ITERATION, key, value):
-        from megatron import get_args
-        args = get_args()
-        my_rank = torch.distributed.get_rank()
-        if ITERATION != DEBUG_ITERATION:
-            return
-        for r in range(torch.distributed.get_world_size()):
-            if my_rank == r:
-                print(" + br/%s; [r%d, i%d]; %s, %.12e" % ("fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
-            torch.distributed.barrier()
-        torch.distributed.barrier()
-        # if my_rank == 0:
-        #     raise Exception("debug.")
-        # else:
-        #     exit(0)
-        exit(0)
-    def debug_model(self, ITERATION, key, use_grad):
-        use_grad = bool(use_grad)
-        tensors = [
-            (p.main_grad.float() if use_grad else p.float())
-            for m in self.models for p in m.parameters()
-        ]
-        count = sum(t.nelement() for t in tensors)
-        return self.debug_base(
-            ITERATION,
-            "model/%s, %s [count %d]" % (
-                "grad" if use_grad else "param",
-                key,
-                count,
-            ),
-            # sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
-            sum(torch.sum(torch.abs(t)) for t in tensors),
-        )
-    def debug_main(self, ITERATION, key, use_grad):
-        use_grad = bool(use_grad)
-        tensors = [
-            p.grad if use_grad else p
-            for g in self.optimizer.param_groups
-            for p in g["params"]
-        ]
-        tensors = [ t.float() for t in tensors ]
-        count = sum(t.nelement() for t in tensors)
-        return self.debug_base(
-            ITERATION,
-            "main/%s, %s [count %d]" % (
-                "grad" if use_grad else "param",
-                key,
-                count,
-            ),
-            sum(torch.sum(torch.abs(t)) for t in tensors),
-        )
+    # @classmethod
+    # def debug_base(cls, ITERATION, key, value):
+    #     from megatron import get_args
+    #     args = get_args()
+    #     my_rank = torch.distributed.get_rank()
+    #     if ITERATION != DEBUG_ITERATION:
+    #         return
+    #     for r in range(torch.distributed.get_world_size()):
+    #         if my_rank == r:
+    #             print(" + br/%s; [r%d, i%d]; %s, %.12e" % ("fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
+    #         torch.distributed.barrier()
+    #     torch.distributed.barrier()
+    #     # if my_rank == 0:
+    #     #     raise Exception("debug.")
+    #     # else:
+    #     #     exit(0)
+    #     exit(0)
+    # def debug_model(self, ITERATION, key, use_grad):
+    #     use_grad = bool(use_grad)
+    #     tensors = [
+    #         (p.main_grad.float() if use_grad else p.float())
+    #         for m in self.models for p in m.parameters()
+    #     ]
+    #     count = sum(t.nelement() for t in tensors)
+    #     return self.debug_base(
+    #         ITERATION,
+    #         "model/%s, %s [count %d]" % (
+    #             "grad" if use_grad else "param",
+    #             key,
+    #             count,
+    #         ),
+    #         # sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
+    #         sum(torch.sum(torch.abs(t)) for t in tensors),
+    #     )
+    # def debug_main(self, ITERATION, key, use_grad):
+    #     use_grad = bool(use_grad)
+    #     tensors = [
+    #         p.grad if use_grad else p
+    #         for g in self.optimizer.param_groups
+    #         for p in g["params"]
+    #     ]
+    #     tensors = [ t.float() for t in tensors ]
+    #     count = sum(t.nelement() for t in tensors)
+    #     return self.debug_base(
+    #         ITERATION,
+    #         "main/%s, %s [count %d]" % (
+    #             "grad" if use_grad else "param",
+    #             key,
+    #             count,
+    #         ),
+    #         sum(torch.sum(torch.abs(t)) for t in tensors),
+    #     )
     # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
 
     @torch.no_grad()
@@ -327,10 +380,8 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         timers = get_timers()
 
         # >>>
-        # self.debug_model_param(ITERATION, "before copy grad.")
-        # self.debug_model_grad(ITERATION, "before copy grad.")
-        # self.debug_main_param(ITERATION, "before copy grad.")
-        # self.debug_main_grad(ITERATION, "before copy grad.")
+        # self.debug_model(ITERATION, "before copy grad.", 0)
+        # self.debug_main(ITERATION, "before copy grad.", 0)
        # <<<
 
        # Copy gradients from model params to main params.
@@ -338,11 +389,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         self._copy_model_grads_to_main_grads(ITERATION)
         timers('optimizer-copy-to-main-grad').stop()
 
-        # >>>
-        # self.debug_model(ITERATION, "after copy grad.", 0)
-        # self.debug_main(ITERATION, "after copy grad.", 1)
-        # <<<
-
         # Do unscale, check for inf, and update grad scaler only for
         # the case that grad scaler is provided.
         if self.grad_scaler:
@@ -358,11 +404,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
 
             # If we found inf/nan, skip the update.
             if found_inf_flag:
-                pax(0, {
-                    "main params" : self.get_main_params(),
-                    "main grads" : self.get_main_grads(),
-                    "found_inf_flag" : found_inf_flag,
-                })
                 return False, None, None
 
         # Clip the main gradients.
@@ -376,41 +417,21 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         num_zeros_in_grad = self.count_zeros() if \
                             self.log_num_zeros_in_grad else None
 
-        # >>>
-        # param = self.optimizer.param_groups[0]["params"][0]
-        # pax(0, {
-        #     "param" : tp(param),
-        #     "grad" : tp(param.grad),
-        # })
-        # <<<
-
-        # >>>
-        # self.debug_main(ITERATION, "before step.", 0)
-        # <<<
-
         # Step the optimizer.
         self.optimizer.step()
 
-        # >>>
-        # self.debug_main(ITERATION, "after step.", 0)
-        # <<<
-
         # Update params from main params.
         timers('optimizer-copy-main-to-model-params').start()
         self._copy_main_params_to_model_params(ITERATION)
         timers('optimizer-copy-main-to-model-params').stop()
 
-        # >>>
-        # self.debug_main_param(ITERATION, "after copy param.")
-        # self.debug_main_grad(ITERATION, "after copy param.")
-        # <<<
-
         # Successful update.
         return True, grad_norm, num_zeros_in_grad
 
 
 # class Float16OptimizerWithFloat16Params(MegatronOptimizer):
-class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
+# class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
+class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
     """Float16 optimizer for fp16 and bf16 data types.
 
     Arguments:
@@ -482,17 +503,11 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
                     fp32_from_float16_params_this_group.append(main_param)
                     # Reset existing state dict key to the new main param.
                     if param in self.optimizer.state:
-                        # >>>
-                        raise Exception("hi.")
-                        # <<<
                         self.optimizer.state[main_param] \
                             = self.optimizer.state.pop(param)
 
                 # fp32 params.
                 elif param.type() == 'torch.cuda.FloatTensor':
-                    # >>>
-                    pax(0, {"param": param})
-                    # <<<
                     fp32_params_this_group.append(param)
                     param_group['params'][i] = param
@@ -512,19 +527,9 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
         # recast preexisting per-param state tensors
         self.optimizer.load_state_dict(self.optimizer.state_dict())
 
-        # >>>
-        # from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
-        # params = self.get_parameters()
-        # pax(0, {
-        #     # "params / 0" : params[0],
-        #     "params" : [ (p.tensor_model_parallel, tp(p)) for p in params ],
-        #     "grads" : [ (param_is_not_tensor_parallel_duplicate(p.grad), tp(p.grad)) for p in params ],
-        # })
-        # <<<
-
     def get_model_parallel_group(self):
         return mpu.get_model_parallel_group()
 
     def zero_grad(self, set_to_none=True):
@@ -541,76 +546,35 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
             _zero_grad_group_helper(group, set_to_none)
 
-    # >>>
-    def reduce_grads(self, model):
-
-        # >>>
-        from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-        from megatron import get_args
-        from megatron import get_timers
-        from megatron.model import DistributedDataParallel as LocalDDP
-        from megatron.model import Float16Module
-        from megatron.utils import unwrap_model
-
-        args = get_args()
-        timers = get_timers()
-        # <<<
-
-        # All-reduce if needed.
-        if args.DDP_impl == 'local':
-            timers('backward-params-all-reduce').start()
-            for model_module in model:
-                model_module.allreduce_gradients()
-            timers('backward-params-all-reduce').stop()
-
-        # All-reduce word_embeddings' grad across first and last stages to ensure
-        # that word_embeddings parameters stay in sync.
-        # This should only run for models that support pipelined model parallelism
-        # (BERT and GPT-2).
-        timers('backward-embedding-all-reduce').start()
-        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
-                mpu.get_pipeline_model_parallel_world_size() > 1:
-            # >>>
-            # raise Exception("[main] ready for weight sync?")
-            # <<<
-            if mpu.is_pipeline_first_stage(ignore_virtual=True):
-                unwrapped_model = model[0]
-            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
-                unwrapped_model = model[-1]
-            else:  # We do not support the interleaved schedule for T5 yet.
-                unwrapped_model = model[0]
-            unwrapped_model = unwrap_model(
-                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
-
-            if unwrapped_model.share_word_embeddings:
-                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
-                if args.DDP_impl == 'local':
-                    grad = word_embeddings_weight.main_grad
-                else:
-                    grad = word_embeddings_weight.grad
-                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
-
-        # All-reduce position_embeddings grad across first (encoder) and split (decoder)
-        # stages to ensure that position embeddings parameters stay in sync.
-        # This should only run for T5 models with pipeline parallelism
-        if mpu.is_rank_in_position_embedding_group() and \
-                mpu.get_pipeline_model_parallel_world_size() > 1 and \
-                args.pipeline_model_parallel_split_rank is not None:
-            # >>>
-            raise Exception("[main] ready for t5 sync?")
-            # <<<
-            unwrapped_model = model[0]
-            unwrapped_model = unwrap_model(
-                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
-            assert args.DDP_impl == 'local', \
-                'T5 model is only supported with local DDP mode'
-            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
-            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
-        timers('backward-embedding-all-reduce').stop()
-
-    def gather_params(self, ITERATION):
-        pass
+    def _collect_main_grad_data_for_unscaling(self):
+
+        main_grads = []
+
+        # fp32 params from float16 ones.
+        for main_group in self.fp32_from_float16_groups:
+            for main_param in main_group:
+                if main_param.grad is not None:
+                    main_grads.append(main_param.grad.data)
+
+        # Append fp32 parameters.
+        for main_group in self.fp32_from_fp32_groups:
+            for main_param in main_group:
+                if main_param.grad is not None:
+                    main_grads.append(main_param.grad.data)
+
+        return main_grads
+
+    def _get_model_and_main_params_data_float16(self):
+        model_data = []
+        main_data = []
+        for model_group, main_group in zip(self.float16_groups,
+                                           self.fp32_from_float16_groups):
+            for model_param, main_param in zip(model_group, main_group):
+                model_data.append(model_param.data)
+                main_data.append(main_param.data)
+        return model_data, main_data
 
     def _copy_model_grads_to_main_grads(self, ITERATION):
         # This only needs to be done for the float16 group.
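The two helpers added on the right-hand side rely on the usual master-weights layout: each fp16/bf16 model parameter is paired with an fp32 "main" copy that the wrapped optimizer actually updates, while true fp32 parameters serve as their own main copy. A toy illustration of that pairing (names and structure simplified; unrelated to the real constructor):

    import torch

    def build_main_param_groups(model_params):
        # Toy version of the pairing: fp16/bf16 params get a detached fp32 clone
        # ("main" param); fp32 params are kept as-is.
        float16_params, fp32_from_float16, fp32_params = [], [], []
        for p in model_params:
            if p.dtype in (torch.half, torch.bfloat16):
                float16_params.append(p)
                fp32_from_float16.append(p.detach().clone().float())
            else:
                fp32_params.append(p)
        return float16_params, fp32_from_float16, fp32_params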
@@ -653,49 +617,12 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
         #     })
         # <<<
 
-    def _collect_main_grad_data_for_unscaling(self):
-
-        main_grads = []
-
-        # fp32 params from float16 ones.
-        for main_group in self.fp32_from_float16_groups:
-            for main_param in main_group:
-                if main_param.grad is not None:
-                    main_grads.append(main_param.grad.data)
-
-        # Append fp32 parameters.
-        for main_group in self.fp32_from_fp32_groups:
-            for main_param in main_group:
-                if main_param.grad is not None:
-                    main_grads.append(main_param.grad.data)
-
-        return main_grads
-
-    def _get_model_and_main_params_data_float16(self):
-        model_data = []
-        main_data = []
-        for model_group, main_group in zip(self.float16_groups,
-                                           self.fp32_from_float16_groups):
-            for model_param, main_param in zip(model_group, main_group):
-                model_data.append(model_param.data)
-                main_data.append(main_param.data)
-        return model_data, main_data
-
     def _copy_main_params_to_model_params(self, ITERATION):
         # Only needed for the float16 params.
         model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                         overflow_buf=self._dummy_overflow_buf)
 
-        # >>>
-        # if ITERATION == DEBUG_ITERATION:
-        #     pax(0, {
-        #         "** branch **" : "** main. **",
-        #         "ITERATION" : ITERATION,
-        #         "model params" : [p for m in self.models for p in m.parameters()],
-        #     })
-        # <<<
-
     def _copy_model_params_to_main_params(self):