chenpangpang / transformers · Commits

Unverified commit 127e81c2, authored Jun 21, 2023 by Zach Mueller, committed by GitHub on Jun 21, 2023

Remove redundant code from TrainingArgs (#24401)

Remove redundant code
Parent: cd927a47
Showing 1 changed file with 13 additions and 25 deletions

src/transformers/training_args.py  (+13, -25)
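The branches dropped here duplicate bookkeeping that `self.distributed_state` (an `accelerate` `PartialState`) already does. The commit itself does not spell this out, so the following is only a sketch of the state object these properties now lean on, not part of the diff:

# Not part of the commit: a minimal sketch of the accelerate PartialState
# attributes exposed through TrainingArguments.distributed_state, which is
# presumably why the per-backend world-size/rank branches became redundant.
from accelerate import PartialState

state = PartialState()              # inspects the launcher / environment once
print(state.num_processes)          # total number of processes (world size)
print(state.process_index)          # global rank of this process
print(state.local_process_index)    # rank of this process on its own node
print(state.is_main_process)        # True only on global rank 0
print(state.is_local_main_process)  # True only on local rank 0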
@@ -1815,15 +1815,10 @@ class TrainingArguments:
        The number of processes used in parallel.
        """
        requires_backends(self, ["torch"])
        if is_torch_tpu_available():
            return xm.xrt_world_size()
        if self.distributed_state is not None:
            return self.distributed_state.num_processes
        elif is_sagemaker_mp_enabled():
            return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
        elif is_sagemaker_dp_enabled():
            return dist.get_world_size()
        elif self.parallel_mode == ParallelMode.DISTRIBUTED:
            return torch.distributed.get_world_size()
        return 1

    @property
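Not part of the diff: a small usage sketch of the `world_size` property shown in the hunk above, assuming a `TrainingArguments` built with an arbitrary `output_dir`. It reads 1 in a single-process run and the number of launched processes otherwise.

from transformers import TrainingArguments

args = TrainingArguments(output_dir="tmp_out")  # "tmp_out" is an arbitrary example path
# The effective (global) batch size scales with the number of parallel processes.
print(f"global batch size: {args.per_device_train_batch_size * args.world_size}")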
@@ -1832,14 +1827,10 @@ class TrainingArguments:
        The index of the current process used.
        """
        requires_backends(self, ["torch"])
        if is_torch_tpu_available():
            return xm.get_ordinal()
        if self.distributed_state is not None:
            return self.distributed_state.process_index
        elif is_sagemaker_mp_enabled():
            return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
        elif is_sagemaker_dp_enabled():
            return dist.get_rank()
        elif self.parallel_mode == ParallelMode.DISTRIBUTED:
            return torch.distributed.get_rank()
        return 0

    @property
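Again not part of the diff, a sketch of how `process_index` is typically consumed: gating work so that only global rank 0 performs it.

from transformers import TrainingArguments

args = TrainingArguments(output_dir="tmp_out")  # arbitrary example path
if args.process_index == 0:
    # Only the globally first process writes the summary; replicas skip it.
    print("main process: saving metrics")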
@@ -1848,14 +1839,11 @@ class TrainingArguments:
        The index of the local process used.
        """
        requires_backends(self, ["torch"])
        if is_torch_tpu_available():
            return xm.get_local_ordinal()
        if self.distributed_state is not None:
            return self.distributed_state.local_process_index
        elif is_sagemaker_mp_enabled():
            return smp.local_rank()
        elif is_sagemaker_dp_enabled():
            return dist.get_rank()
        elif self.parallel_mode == ParallelMode.DISTRIBUTED:
            return self.local_rank
        return 0

    @property
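A hedged sketch (not from the commit) of the usual role of `local_process_index`: picking the GPU that belongs to this process on its own node. Device assignment is normally handled by the launcher, so this is illustrative only.

import torch
from transformers import TrainingArguments

args = TrainingArguments(output_dir="tmp_out")  # arbitrary example path
if torch.cuda.is_available():
    # One GPU per local rank on a multi-GPU node.
    torch.cuda.set_device(args.local_process_index)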
@@ -1944,19 +1932,19 @@ class TrainingArguments:
        """
        if is_torch_available() and self.world_size > 1:
            main_process_desc = "main process"
            if local:
                is_main_process = self.local_process_index == 0
                main_process_desc = "main local process"
            main_process_desc = "main local process" if local else "main process"
            if self.distributed_state is not None:
                is_main_process = (
                    self.distributed_state.is_local_main_process if local else self.distributed_state.is_main_process
                )
            elif is_sagemaker_mp_enabled():
                is_main_process = smp.rank() == 0
            else:
                is_main_process = self.process_index == 0

            try:
                if not is_main_process:
                    # tell all replicas to wait
                    logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
                    if is_torch_tpu_available():
                        xm.rendezvous(desc)
                    else:
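The last hunk sits inside the `main_process_first` context manager. A usage sketch (not from the commit) of how it is meant to be called, so that only the (local) main process runs a cached preprocessing step while the other replicas wait at the barrier and then re-enter:

from transformers import TrainingArguments

args = TrainingArguments(output_dir="tmp_out")  # arbitrary example path

def preprocess():
    # Placeholder for work whose result lands in a shared cache,
    # e.g. a datasets.map() call in the example scripts.
    pass

with args.main_process_first(desc="dataset map pre-processing"):
    # Main process runs first; replicas run afterwards, usually hitting the cache.
    preprocess()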