Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3d61d6b3
Commit
3d61d6b3
authored
Mar 30, 2023
by
qianyj
Browse files
initial files for ResNet50
parent
d3a70caf
Changes
166
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
648 additions
and
0 deletions
+648
-0
orbit/utils/loop_fns.py
orbit/utils/loop_fns.py
+205
-0
orbit/utils/summary_manager.py
orbit/utils/summary_manager.py
+110
-0
orbit/utils/tpu_summaries.py
orbit/utils/tpu_summaries.py
+145
-0
orbit/utils/tpu_summaries_test.py
orbit/utils/tpu_summaries_test.py
+120
-0
scripts-run/single_process.sh
scripts-run/single_process.sh
+34
-0
scripts-run/single_process_xla.sh
scripts-run/single_process_xla.sh
+34
-0
No files found.
orbit/utils/loop_fns.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for creating loop functions."""
from
orbit.utils
import
tpu_summaries
import
tensorflow
as
tf
def create_loop_fn(step_fn):
  """Creates a loop function driven by a Python `while` loop.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. There are no constraints on the return value of
      the function (except that it must be compatible with any `reduce_fn`
      provided to the returned `loop_fn`).

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters, as
    well as optional `state` and `reduce_fn` parameters for accumulating state
    over multiple iterations of the loop. See the `loop_fn` definition below
    for additional details.
  """

  def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Additionally, state may be accumulated across iterations of the loop.
    Conceptually, state accumulation is handled roughly as follows:

        for _ in range(num_steps):
          step_outputs = step_fn(iterator)
          state = reduce_fn(state, step_outputs)
        return state

    However, the implementation is slightly more complicated in order to
    support looping until the iterator is exhausted (when `num_steps == -1`)
    and to properly catch exceptions when running under async remote eager
    (as is the case in TPU training setups involving separate
    coordinator/worker machines).

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. If `num_steps == -1`, will
        iterate until exhausting the iterator.
      state: An optional initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`, or `None` if `state` and
      `reduce_fn` are not provided.
    """
    try:
      steps_done = 0
      # Wrap the loop body in `async_scope` so that an `OutOfRangeError`
      # raised asynchronously by a remote worker (async remote eager) is
      # surfaced here and can be caught below.
      with tf.experimental.async_scope():
        while num_steps == -1 or steps_done < num_steps:
          step_outputs = step_fn(iterator)
          if reduce_fn is not None:
            state = reduce_fn(state, step_outputs)
          steps_done += 1
      return state
    except (StopIteration, tf.errors.OutOfRangeError):
      # Iterator exhaustion is an expected way to terminate when
      # `num_steps == -1`; clear the pending async error and return whatever
      # state was accumulated so far.
      tf.experimental.async_clear_error()
      return state

  return loop_fn
def create_tf_while_loop_fn(step_fn):
  """Creates a loop function compatible with TF's AutoGraph loop conversion.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters. If
    called inside a `tf.function`, the loop will be converted by AutoGraph
    into a `tf.while_loop` construct. See the `loop_fn` definition below for
    additional details.
  """

  def loop_fn(iterator, num_steps):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor` (a Python value would
        retrigger tracing on every distinct value).
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    for _ in tf.range(num_steps):
      # Reset the name scope so ops created inside the AutoGraph-converted
      # `tf.while_loop` don't get "while/" as a name prefix.
      with tf.name_scope(""):
        step_fn(iterator)

  return loop_fn
def create_tf_while_loop_fn_with_state(step_fn):
  """Creates a TF while loop function with state.

  This function is similar to `create_tf_while_loop_fn`, but allowing a
  `state` to be accumulated over multiple iterations of the loop. Note that
  the structure of the `state` cannot be changed across iterations.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator`, `num_steps`, `state` and
    `reduce_fn` parameters. If called inside a `tf.function`, the loop will be
    converted by AutoGraph into a `tf.while_loop` construct. See the `loop_fn`
    definition below for additional details.
  """

  def loop_fn_with_state(iterator, num_steps, state, reduce_fn):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
      state: An initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor`.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    def _relax_tensor_shape(t):
      """Returns a `TensorShape` with all `None` dimensions."""
      if not tf.is_tensor(t):
        return None
      shape = t.shape
      if shape.rank is not None and shape.rank > 0:
        return tf.TensorShape([None] * shape.rank)
      return shape

    def _relax_shape_structure(s):
      """Returns the relaxed shape of the input nested structure `s`."""
      return tf.nest.pack_sequence_as(
          state, [_relax_tensor_shape(t) for t in tf.nest.flatten(s)])

    for _ in tf.range(num_steps):
      # Clear out the outer name scope so the ops created inside
      # `tf.while_loop` don't get "while/" as a name prefix.
      with tf.name_scope(""):
        # Relax the shapes within the loop, so the shape of `state` can
        # change across iterations. This is useful to aggregate outputs from
        # each step and concat to `state`.
        tf.autograph.experimental.set_loop_options(
            shape_invariants=[(state, _relax_shape_structure(state))])
        step_outputs = step_fn(iterator)
        state = reduce_fn(state, step_outputs)
    return state

  return loop_fn_with_state
class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
  """Implements a two-program approach for optimizing summaries on TPU.

  This version works with the result of `create_tf_while_loop_fn`.
  """

  def __call__(self, iterator, num_steps):
    # Run exactly one step through the summary-enabled program when summaries
    # are due, then run the remaining steps through the summary-free program.
    if tf.summary.should_record_summaries():
      output = self.with_summaries(iterator, tf.constant(1))
      num_steps -= 1
    if num_steps >= 1:
      output = self.without_summaries(iterator, num_steps)
    return output
orbit/utils/summary_manager.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides a utility class for managing summary writing."""
import
os
import
tensorflow
as
tf
class SummaryManager:
  """A utility class for managing summary writing."""

  def __init__(self, summary_dir, summary_fn, global_step=None):
    """Initializes the `SummaryManager` instance.

    Args:
      summary_dir: The directory in which to write summaries. If `None`, all
        summary writing operations provided by this class are no-ops.
      summary_fn: A callable defined accepting `name`, `value`, and `step`
        parameters, making calls to `tf.summary` functions to write summaries.
      global_step: A `tf.Variable` containing the global step value.
    """
    self._enabled = summary_dir is not None
    self._summary_dir = summary_dir
    self._summary_fn = summary_fn
    # Lazily-created writers, keyed by path relative to `summary_dir`.
    self._summary_writers = {}
    if global_step is None:
      self._global_step = tf.summary.experimental.get_step()
    else:
      self._global_step = global_step

  def summary_writer(self, relative_path=""):
    """Returns the underlying summary writer for a specific subdirectory.

    Args:
      relative_path: The current path in which to write summaries, relative to
        the summary directory. By default it is empty, which corresponds to
        the root directory.
    """
    # A single membership test replaces the original redundant
    # truthiness-plus-membership check and double dict lookup.
    if relative_path in self._summary_writers:
      return self._summary_writers[relative_path]
    if self._enabled:
      writer = tf.summary.create_file_writer(
          os.path.join(self._summary_dir, relative_path))
    else:
      writer = tf.summary.create_noop_writer()
    self._summary_writers[relative_path] = writer
    return writer

  def flush(self):
    """Flushes the underlying summary writers."""
    if self._enabled:
      tf.nest.map_structure(tf.summary.flush, self._summary_writers)

  def write_summaries(self, summary_dict):
    """Writes summaries for the given dictionary of values.

    This recursively creates subdirectories for any nested dictionaries
    provided in `summary_dict`, yielding a hierarchy of directories which will
    then be reflected in the TensorBoard UI as different colored curves.

    For example, users may evaluate on multiple datasets and return
    `summary_dict` as a nested dictionary:

        {
            "dataset1": {
                "loss": loss1,
                "accuracy": accuracy1
            },
            "dataset2": {
                "loss": loss2,
                "accuracy": accuracy2
            },
        }

    This will create two subdirectories, "dataset1" and "dataset2", inside the
    summary root directory. Each directory will contain event files including
    both "loss" and "accuracy" summaries.

    Args:
      summary_dict: A dictionary of values. If any value in `summary_dict` is
        itself a dictionary, then the function will create a subdirectory with
        name given by the corresponding key. This is performed recursively.
        Leaf values are then summarized using the summary writer instance
        specific to the parent relative path.
    """
    if not self._enabled:
      return
    self._write_summaries(summary_dict)

  def _write_summaries(self, summary_dict, relative_path=""):
    # Recursive helper: dict values descend into a subdirectory named after
    # their key; leaf values are written via `self._summary_fn`.
    for name, value in summary_dict.items():
      if isinstance(value, dict):
        self._write_summaries(
            value, relative_path=os.path.join(relative_path, name))
      else:
        with self.summary_writer(relative_path).as_default():
          self._summary_fn(name, value, step=self._global_step)
orbit/utils/tpu_summaries.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains utilities for TPU summary optimization."""
import
contextlib
import
functools
import
tensorflow
as
tf
@contextlib.contextmanager
def _soft_device_placement():
  """Context manager for soft device placement, allowing summaries on CPU."""
  previous = tf.config.get_soft_device_placement()
  try:
    tf.config.set_soft_device_placement(True)
    yield
  finally:
    # Always restore the caller's setting, even if the body raised.
    tf.config.set_soft_device_placement(previous)
class OptionalSummariesFunction:
  """Wrapper that provides versions of a function with and without summaries.

  This is a utility class for implementing optimized summary recording via a
  two-function approach, specifically important for TPUs. Two `tf.function`
  versions of a given `function` are created: one with soft device placement
  enabled (for use on steps that require summary writing), and one with
  summary writing and soft device placement entirely disabled (for use on all
  other steps). This removes any performance impact of summaries on steps
  where they aren't recorded (b/148418718).

  This class can be used as a base class to implement summary optimizations
  for a function with a specific signature. For example, to implement
  efficient TPU summaries for a standard `train()` method (as in
  `orbit.AbstractTrainer`):

      class TrainFunctionWithSummaries(orbit.utils.OptionalSummariesFunction):
        '''Implements a two-program approach for summaries on TPU.'''

        def __call__(self, num_steps):
          if tf.summary.should_record_summaries():
            output = self.with_summaries(tf.constant(1))
            num_steps -= 1
          if num_steps >= 1:
            output = self.without_summaries(num_steps)
          return output

  This can be used directly or to implement a decorator:

      def train_function_with_summaries(function=None, **kwargs):
        if function is not None:
          return TrainFunctionWithSummaries(function, **kwargs)
        return functools.partial(TrainFunctionWithSummaries, **kwargs)

  The decorator can be applied directly to `train()` methods:

      @train_function_with_summaries
      def train(self, num_steps):
        ...

  A similar approach can be implemented for functions with different
  signatures.

  Note: The above approach assumes that the frequency of summary writing is
  based on a step interval that is divisible by the number of steps executed
  in each call to the `train()` function. This is enforced by the
  `orbit.Controller`.

  This wrapper properly handles instance methods (see `__get__`).

  Attributes:
    with_summaries: A wrapped version of the underlying function with
      summaries enabled (using whatever the active predicate is for
      `tf.summary.record_if`), and placed inside a "soft device placement"
      context to enable summary recording on TPU.
    without_summaries: A wrapped version of the underlying function with all
      summary recording disabled.
  """

  def __init__(self, function, **tf_function_kwargs):
    """Constructs an instance wrapping the given `function`.

    The given `function` is wrapped twice: Once in a "soft device placement"
    context (allowing summaries to also run on TPU), and once with summary
    recording entirely disabled.

    Both of these versions are compiled via `tf.function` (optionally using
    any supplied `tf.function` settings), and made available as attributes.

    Args:
      function: The underlying function to wrap.
      **tf_function_kwargs: Additional arguments to pass to `tf.function`.
    """

    @tf.function(**tf_function_kwargs)
    @functools.wraps(function)
    def with_summaries(*args, **kwargs):
      # Soft device placement lets summary ops fall back to CPU on TPU.
      with _soft_device_placement():
        return function(*args, **kwargs)

    @tf.function(**tf_function_kwargs)
    @functools.wraps(function)
    def without_summaries(*args, **kwargs):
      # Hard-disable summary recording for the fast path.
      with tf.summary.record_if(False):
        return function(*args, **kwargs)

    self.with_summaries = with_summaries
    self.without_summaries = without_summaries

  def __get__(self, instance, owner):
    """Allows this class to be used to wrap methods as well as free functions.

    For `tf.function` to work properly in all cases (e.g., when an
    input_signature is specified), any `tf.function`-converted methods must be
    properly bound to an instance if they are called as an instance method.
    This is done by implementing this `__get__` method of the descriptor
    protocol, and forwarding to the `__get__` method on the underlying
    `tf.function`s.

    Args:
      instance: The instance to bind to.
      owner: The class type of the instance.

    Returns:
      A new bound instance of `TpuDiscretionarySummariesFunctions`.
    """
    bound = object.__new__(self.__class__)
    # pytype: disable=attribute-error  # See b/162476201.
    bound.with_summaries = self.with_summaries.__get__(instance, owner)
    bound.without_summaries = self.without_summaries.__get__(instance, owner)
    # pytype: enable=attribute-error
    return bound
orbit/utils/tpu_summaries_test.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for orbit.utils.tpu_summaries."""
import
functools
import
os
from
orbit.utils
import
common
from
orbit.utils
import
tpu_summaries
import
tensorflow
as
tf
class TrainFunctionWithSummaries(tpu_summaries.OptionalSummariesFunction):
  """Implements a two-program approach for summaries on TPU."""

  def __call__(self, num_steps):
    # Take one step with summaries when they are due, then run the remainder
    # through the summary-free program.
    if tf.summary.should_record_summaries():
      output = self.with_summaries(tf.constant(1))
      num_steps -= 1
    if num_steps >= 1:
      output = self.without_summaries(num_steps)
    return output
def train_function_with_summaries(function=None, **kwargs):
  """Decorator form of `TrainFunctionWithSummaries`.

  Supports both bare use (`@train_function_with_summaries`) and
  parameterized use (`@train_function_with_summaries(...)`).
  """
  if function is None:
    # Called with arguments only; return a decorator awaiting the function.
    return functools.partial(TrainFunctionWithSummaries, **kwargs)
  return TrainFunctionWithSummaries(function, **kwargs)
class DummyTrainer(tf.Module):
  """Minimal trainer exposing equivalent train functions for testing."""

  def __init__(self):
    self.step_counter = common.create_global_step()

  @train_function_with_summaries
  def train_with_tpu_summary_optimization(self, num_steps):
    """Train loop wrapped directly by the summary-optimization decorator."""
    for _ in tf.range(num_steps):
      tf.summary.scalar("step", self.step_counter, step=self.step_counter)
      self.step_counter.assign_add(1)
    return self.step_counter

  @train_function_with_summaries(
      input_signature=[tf.TensorSpec((), dtype=tf.int32)])
  def train_with_tpu_summary_optimization_and_input_signature(self, num_steps):
    """Same train loop, decorated with an explicit `input_signature`."""
    for _ in tf.range(num_steps):
      tf.summary.scalar("step", self.step_counter, step=self.step_counter)
      self.step_counter.assign_add(1)
    return self.step_counter

  def train_with_tpu_summary_optimization_no_decorator(self, num_steps):
    """Undecorated train loop; wrapped manually inside the tests."""
    for _ in tf.range(num_steps):
      tf.summary.scalar("step", self.step_counter, step=self.step_counter)
      self.step_counter.assign_add(1)
    return self.step_counter
class TpuSummariesTest(tf.test.TestCase):
  """Checks that the two-program summary optimization records correctly."""

  def setUp(self):
    super().setUp()
    self.trainer = DummyTrainer()

  def _get_events_from_logdir(self, logdir):
    """Returns the summary events from the single event file in `logdir`."""
    event_files = tf.io.gfile.listdir(logdir)
    self.assertLen(event_files, 1)
    path = os.path.join(logdir, event_files[0])
    events = list(tf.compat.v1.train.summary_iterator(path))
    return [e for e in events if e.WhichOneof("what") == "summary"]

  def _validate_tpu_summary_optimization(self, function, *args, **kwargs):
    """Runs `function` 4x10 steps, recording a summary every 20 steps."""
    logdir = self.get_temp_dir()
    with tf.summary.create_file_writer(logdir).as_default():
      with tf.summary.record_if(lambda: self.trainer.step_counter % 20 == 0):
        for _ in range(4):
          output = function(tf.constant(10), *args, **kwargs)
    events = self._get_events_from_logdir(logdir)
    # 40 total steps with a period of 20 yields exactly two summaries.
    self.assertLen(events, 2)
    self.assertEqual(events[0].step, 0)
    self.assertEqual(events[1].step, 20)
    return output

  def test_train_with_tpu_summary_optimization(self):
    output = self._validate_tpu_summary_optimization(
        self.trainer.train_with_tpu_summary_optimization)
    self.assertEqual(output, self.trainer.step_counter.numpy())

  def test_train_with_tpu_summary_optimization_no_decorator(self):
    optimized = train_function_with_summaries(
        self.trainer.train_with_tpu_summary_optimization_no_decorator)
    output = self._validate_tpu_summary_optimization(optimized)
    self.assertEqual(output, self.trainer.step_counter.numpy())

  def test_train_with_tpu_summary_optimization_and_input_signature(self):
    output = self._validate_tpu_summary_optimization(
        self.trainer.train_with_tpu_summary_optimization_and_input_signature)
    self.assertEqual(output, self.trainer.step_counter.numpy())
    # Both wrapped programs must preserve the declared input signature.
    function = (
        self.trainer.train_with_tpu_summary_optimization_and_input_signature)
    expected = (tf.TensorSpec((), dtype=tf.int32),)
    self.assertEqual(function.with_summaries.input_signature, expected)
    self.assertEqual(function.without_summaries.input_signature, expected)
if __name__ == "__main__":
  # Run the test suite when this module is executed directly.
  tf.test.main()
scripts-run/single_process.sh
0 → 100644
View file @
3d61d6b3
#!/bin/bash
# Launches one ResNet50 training process per local MPI rank, binding each
# rank to its own GPU and NUMA node. Expects to be started under mpirun
# (OMPI_COMM_WORLD_* variables) with `data_dir` set in the environment.
lrank=$OMPI_COMM_WORLD_LOCAL_RANK
drank=$OMPI_COMM_WORLD_RANK

APP="python3 ./official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py --num_gpus=1 --skip_eval=true --batch_size=512 --train_epochs=90 --use_synthetic_data=false --distribution_strategy=multi_worker_mirrored --all_reduce_alg=nccl --dtype=fp32 --data_dir=${data_dir} --task_index=${drank}"

# The original script duplicated an identical arm for ranks 0-3, differing
# only in the rank number; one arm parameterized by ${lrank} is equivalent.
case ${lrank} in
[0-3])
  export HIP_VISIBLE_DEVICES=${lrank}
  export UCX_NET_DEVICES=mlx5_0:1
  export UCX_IB_PCI_BW=mlx5_0:50Gbs
  # ${APP} is intentionally unquoted so the command string word-splits
  # into the program and its arguments.
  numactl --cpunodebind=${lrank} --membind=${lrank} ${APP}
  ;;
esac
scripts-run/single_process_xla.sh
0 → 100644
View file @
3d61d6b3
#!/bin/bash
# XLA variant of single_process.sh: launches one ResNet50 training process
# per local MPI rank with auto-clustering JIT enabled via TF_XLA_FLAGS.
lrank=$OMPI_COMM_WORLD_LOCAL_RANK
drank=$OMPI_COMM_WORLD_RANK

APP="python3 ./official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py --num_gpus=1 --skip_eval=true --batch_size=512 --train_epochs=90 --use_synthetic_data=false --distribution_strategy=multi_worker_mirrored --all_reduce_alg=nccl --dtype=fp32 --data_dir=${data_dir} --task_index=${drank}"

# The original script duplicated an identical arm for ranks 0-3; one arm
# parameterized by ${lrank} is equivalent.
case ${lrank} in
[0-3])
  export HIP_VISIBLE_DEVICES=${lrank}
  export UCX_NET_DEVICES=mlx5_0:1
  export UCX_IB_PCI_BW=mlx5_0:50Gbs
  # Bug fix: the original assigned TF_XLA_FLAGS on its own line WITHOUT
  # `export`, so the variable never reached the python child process and
  # XLA auto-JIT was silently never enabled.
  export TF_XLA_FLAGS="--tf_xla_auto_jit=2"
  # ${APP} is intentionally unquoted so the command string word-splits.
  numactl --cpunodebind=${lrank} --membind=${lrank} ${APP}
  ;;
esac
Prev
1
…
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment