Unverified Commit 01c7fb04 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[DeepSpeed] simplify init (#10762)

parent 0486ccdd
......@@ -22,7 +22,6 @@ import os
import re
import tempfile
from pathlib import Path
from types import SimpleNamespace
from .utils import logging
from .utils.versions import require_version
......@@ -430,16 +429,12 @@ def init_deepspeed(trainer, num_training_steps):
"enabled": True,
}
# for clarity extract the specific cl args that are being passed to deepspeed
ds_args = dict(local_rank=args.local_rank)
# keep for quick debug:
# from pprint import pprint; pprint(config)
# init that takes part of the config via `args`, and the bulk of it via `config_params`
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
model, optimizer, _, lr_scheduler = deepspeed.initialize(
args=SimpleNamespace(**ds_args), # expects an obj
model=model,
model_parameters=model_parameters,
config_params=config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment