Unverified Commit 318fbf11 authored by Kirigaya Kazuto's avatar Kirigaya Kazuto Committed by GitHub
Browse files

[NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style (#1559)

parent b0f4c0bd
...@@ -778,4 +778,4 @@ class OneFOneBPipelineEngine(PipelineEngineBase): ...@@ -778,4 +778,4 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
criterion: Callable = None, criterion: Callable = None,
checkpoint: bool = False) -> None: checkpoint: bool = False) -> None:
use_1F1B = True use_1F1B = True
super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint) super().__init__(module_partitions, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, checkpoint)
\ No newline at end of file
...@@ -26,13 +26,9 @@ class MultiTensorApply(object): ...@@ -26,13 +26,9 @@ class MultiTensorApply(object):
raise RuntimeError( raise RuntimeError(
"Attempted to call MultiTensorApply method, but MultiTensorApply " "Attempted to call MultiTensorApply method, but MultiTensorApply "
"is not available, possibly because Apex was installed without " "is not available, possibly because Apex was installed without "
"--cpp_ext --cuda_ext. Original import error message:", "--cpp_ext --cuda_ext. Original import error message:", MultiTensorApply.import_err)
MultiTensorApply.import_err)
def __call__(self, op, noop_flag_buffer, tensor_lists, *args): def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
self.check_avail() self.check_avail()
return op(self.chunk_size, return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
noop_flag_buffer,
tensor_lists,
*args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment