Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
318fbf11
Unverified
Commit
318fbf11
authored
Sep 08, 2022
by
Kirigaya Kazuto
Committed by
GitHub
Sep 08, 2022
Browse files
[NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style (#1559)
parent
b0f4c0bd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
7 deletions
+3
-7
colossalai/pipeline/rpc/PipelineBase.py
colossalai/pipeline/rpc/PipelineBase.py
+1
-1
colossalai/utils/multi_tensor_apply/multi_tensor_apply.py
colossalai/utils/multi_tensor_apply/multi_tensor_apply.py
+2
-6
No files found.
colossalai/pipeline/rpc/PipelineBase.py
View file @
318fbf11
...
@@ -778,4 +778,4 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
...
@@ -778,4 +778,4 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
criterion
:
Callable
=
None
,
criterion
:
Callable
=
None
,
checkpoint
:
bool
=
False
)
->
None
:
checkpoint
:
bool
=
False
)
->
None
:
use_1F1B
=
True
use_1F1B
=
True
super
().
__init__
(
module_partitions
,
stage_num
,
num_microbatches
,
device
,
use_1F1B
,
chunk
,
criterion
,
checkpoint
)
super
().
__init__
(
module_partitions
,
stage_num
,
num_microbatches
,
device
,
use_1F1B
,
chunk
,
criterion
,
checkpoint
)
\ No newline at end of file
colossalai/utils/multi_tensor_apply/multi_tensor_apply.py
View file @
318fbf11
...
@@ -26,13 +26,9 @@ class MultiTensorApply(object):
...
@@ -26,13 +26,9 @@ class MultiTensorApply(object):
raise
RuntimeError
(
raise
RuntimeError
(
"Attempted to call MultiTensorApply method, but MultiTensorApply "
"Attempted to call MultiTensorApply method, but MultiTensorApply "
"is not available, possibly because Apex was installed without "
"is not available, possibly because Apex was installed without "
"--cpp_ext --cuda_ext. Original import error message:"
,
"--cpp_ext --cuda_ext. Original import error message:"
,
MultiTensorApply
.
import_err
)
MultiTensorApply
.
import_err
)
def
__call__
(
self
,
op
,
noop_flag_buffer
,
tensor_lists
,
*
args
):
def
__call__
(
self
,
op
,
noop_flag_buffer
,
tensor_lists
,
*
args
):
self
.
check_avail
()
self
.
check_avail
()
return
op
(
self
.
chunk_size
,
return
op
(
self
.
chunk_size
,
noop_flag_buffer
,
tensor_lists
,
*
args
)
noop_flag_buffer
,
tensor_lists
,
*
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment