chenpangpang / transformers · Commits · 127e81c2
"docs/vscode:/vscode.git/clone" did not exist on "6c08840628a22a4d53ae563d1041479649d1a8e7"
Commit 127e81c2 (unverified), authored Jun 21, 2023 by Zach Mueller, committed by GitHub on Jun 21, 2023

Remove redundant code from TrainingArgs (#24401)

Remove redundant code

Parent: cd927a47
Showing 1 changed file with 13 additions and 25 deletions.

src/transformers/training_args.py (+13, -25) @ 127e81c2
@@ -1815,15 +1815,10 @@ class TrainingArguments:
         The number of processes used in parallel.
         """
         requires_backends(self, ["torch"])
-        if is_torch_tpu_available():
-            return xm.xrt_world_size()
+        if self.distributed_state is not None:
+            return self.distributed_state.num_processes
         elif is_sagemaker_mp_enabled():
             return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
-        elif is_sagemaker_dp_enabled():
-            return dist.get_world_size()
-        elif self.parallel_mode == ParallelMode.DISTRIBUTED:
-            return torch.distributed.get_world_size()
         return 1

     @property
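In this first hunk, `world_size` stops querying each backend (TPU, `torch.distributed`, SageMaker DP) separately and instead reads the count from `self.distributed_state`, which elsewhere in this file is presumably set up as an `accelerate.PartialState` (not shown in this diff). A minimal sketch, not the library's code, of what that state object provides:

# Minimal sketch, assuming `accelerate` is installed and that
# `distributed_state` is an accelerate PartialState (an assumption here,
# since its construction is not part of this diff).
from accelerate import PartialState

state = PartialState()  # detects the launch environment: single process, DDP, TPU, ...

# One attribute replaces the removed xm.xrt_world_size(),
# dist.get_world_size(), and torch.distributed.get_world_size() branches.
print("world size:", state.num_processes)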
@@ -1832,14 +1827,10 @@ class TrainingArguments:
         The index of the current process used.
         """
         requires_backends(self, ["torch"])
-        if is_torch_tpu_available():
-            return xm.get_ordinal()
+        if self.distributed_state is not None:
+            return self.distributed_state.process_index
         elif is_sagemaker_mp_enabled():
             return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
-        elif is_sagemaker_dp_enabled():
-            return dist.get_rank()
-        elif self.parallel_mode == ParallelMode.DISTRIBUTED:
-            return torch.distributed.get_rank()
         return 0

     @property
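`process_index` follows the same pattern: the global rank now comes from `distributed_state.process_index` rather than `xm.get_ordinal()` or `torch.distributed.get_rank()`. A hypothetical usage sketch (the argument values are placeholders, not from this commit), showing how a script might gate rank-0-only work on this property:

# Hypothetical usage sketch; requires torch to be installed, since the
# property calls requires_backends(self, ["torch"]).
from transformers import TrainingArguments

args = TrainingArguments(output_dir="out")  # other fields left at defaults for brevity

if args.process_index == 0:
    print(f"running on {args.world_size} process(es); this is the global main process")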
@@ -1848,14 +1839,11 @@ class TrainingArguments:
         The index of the local process used.
         """
         requires_backends(self, ["torch"])
-        if is_torch_tpu_available():
-            return xm.get_local_ordinal()
+        if self.distributed_state is not None:
+            return self.distributed_state.local_process_index
         elif is_sagemaker_mp_enabled():
             return smp.local_rank()
-        elif is_sagemaker_dp_enabled():
-            return dist.get_rank()
-        elif self.parallel_mode == ParallelMode.DISTRIBUTED:
-            return self.local_rank
         return 0

     @property
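`local_process_index` likewise delegates to `distributed_state.local_process_index`, the rank within the current node, which is the value typically used to pick a per-process GPU. A short, hypothetical sketch of that kind of use:

# Hypothetical sketch: map each local process to its own CUDA device,
# falling back to CPU when CUDA is unavailable.
import torch
from transformers import TrainingArguments

args = TrainingArguments(output_dir="out")

if torch.cuda.is_available():
    device = torch.device("cuda", args.local_process_index)
else:
    device = torch.device("cpu")
print(f"local rank {args.local_process_index} -> {device}")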
@@ -1944,19 +1932,19 @@ class TrainingArguments:
         """
         if is_torch_available() and self.world_size > 1:
-            main_process_desc = "main process"
-            if local:
-                is_main_process = self.local_process_index == 0
-                main_process_desc = "main local process"
+            main_process_desc = "main local process" if local else "main process"
+            if self.distributed_state is not None:
+                is_main_process = (
+                    self.distributed_state.is_local_main_process if local else self.distributed_state.is_main_process
+                )
             elif is_sagemaker_mp_enabled():
                 is_main_process = smp.rank() == 0
-            else:
-                is_main_process = self.process_index == 0

             try:
                 if not is_main_process:
                     # tell all replicas to wait
                     logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")

                     if is_torch_tpu_available():
                         xm.rendezvous(desc)
                     else:
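The last hunk is inside the `main_process_first` context manager, which now asks `distributed_state` whether the current process is the (local) main process instead of comparing rank indices by hand. A usage sketch of that context manager, where `training_args` is assumed to be a `TrainingArguments` instance and `raw_dataset` / `tokenize` are hypothetical placeholders:

# Hypothetical usage sketch: the main process runs (and caches) the
# preprocessing first while the other replicas wait at the barrier,
# then everyone proceeds with the cached result.
with training_args.main_process_first(desc="dataset map pre-processing"):
    tokenized = raw_dataset.map(tokenize, batched=True)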