Unverified Commit b08aa6f3 authored by Jon Eyolfson's avatar Jon Eyolfson Committed by GitHub
Browse files

Improve starred expressions (#696)

* Improve starred expressions

`deepspeed/profiling/flops_profiler/profiler.py` uses starred expressions
that are no longer valid with [PEP 617][1]. The new Python parser is in 3.9,
and this change allows DeepSpeed to run with the newest Python version. I have
not checked all locations that has this issue. However, this change allows me
to run simple examples.

[1]: https://www.python.org/dev/peps/pep-0617/

* Match style for "Improve starred expressions", although readability suffers

The style guide might need to be updated for this new use case of expressions.
Python [Issue 40631][1] includes more discussion on the change.

[1]: https://bugs.python.org/issue40631

Co-authored-by: default avatarCheng Li <pistasable@gmail.com>
parent 4f1d827c
...@@ -727,12 +727,13 @@ def get_model_profile( ...@@ -727,12 +727,13 @@ def get_model_profile(
else: else:
try: try:
batch = torch.ones(()).new_empty( batch = torch.ones(()).new_empty(
(*input_res), (*input_res,
),
dtype=next(model.parameters()).dtype, dtype=next(model.parameters()).dtype,
device=next(model.parameters()).device, device=next(model.parameters()).device,
) )
except StopIteration: except StopIteration:
batch = torch.ones(()).new_empty((*input_res)) batch = torch.ones(()).new_empty((*input_res, ))
_ = model(batch) _ = model(batch)
prof.start_profile(ignore_list=ignore_modules) prof.start_profile(ignore_list=ignore_modules)
...@@ -744,12 +745,13 @@ def get_model_profile( ...@@ -744,12 +745,13 @@ def get_model_profile(
else: else:
try: try:
batch = torch.ones(()).new_empty( batch = torch.ones(()).new_empty(
(*input_res), (*input_res,
),
dtype=next(model.parameters()).dtype, dtype=next(model.parameters()).dtype,
device=next(model.parameters()).device, device=next(model.parameters()).device,
) )
except StopIteration: except StopIteration:
batch = torch.ones(()).new_empty((*input_res)) batch = torch.ones(()).new_empty((*input_res, ))
_ = model(batch) _ = model(batch)
flops = prof.get_total_flops() flops = prof.get_total_flops()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment