Commit 56636169 authored by jiaruifang, committed by Frank Lee

polish code

parent d271f259
@@ -18,8 +18,7 @@ from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from ._zero3_utils import (cast_tensor_to_fp32, chunk_and_pad, get_gradient_predivide_factor)
-# from ._zero3_utils import cast_float_arguments, cast_tensor_to_fp16
+from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16)
 class ShardedModelV2(nn.Module):
@@ -80,8 +79,7 @@ class ShardedModelV2(nn.Module):
         self._require_backward_grad_sync: bool = True
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
-        # TODO args can be Long!
-        # args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
+        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         return outputs
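For context, the two helpers referenced in this diff come from `_zero3_utils.py`. Below is a minimal sketch of how they plausibly behave; only the names `cast_float_arguments` and `cast_tensor_to_fp16` come from the diff, and the bodies are assumptions rather than the repository's exact code. The key point is that `cast_tensor_to_fp16` converts only floating-point tensors, so integer arguments such as `LongTensor` token indices pass through unchanged, which addresses the removed `# TODO args can be Long!` concern.

```python
# Illustrative sketch only: names taken from the diff, bodies assumed.
from typing import Any, Callable, Tuple

import torch


def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    # Cast only floating-point tensors to half precision; leave
    # Long/Int/Bool tensors untouched so integer args survive intact.
    if torch.is_floating_point(tensor):
        return tensor.half()
    return tensor


def cast_float_arguments(fn: Callable[[torch.Tensor], torch.Tensor],
                         *args: Any, **kwargs: Any) -> Tuple[tuple, dict]:
    # Apply `fn` to every tensor in args/kwargs; pass non-tensors through.
    def apply(x: Any) -> Any:
        return fn(x) if isinstance(x, torch.Tensor) else x

    return tuple(apply(a) for a in args), {k: apply(v) for k, v in kwargs.items()}
```

Under these semantics, the newly enabled line `args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)` halves float inputs before they reach the sharded module while leaving integer arguments as-is.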