"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8c7267f1cf1881e992244eb245f5927e7b60e3aa"
Unverified Commit 0c7f93f5, authored by fxmarty and committed by GitHub
Browse files

Fix nn.init.trunc_normal_ call on torch.float16 data (#21789)

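Fix the nn.init.trunc_normal_ call on half-precision (torch.float16) data: run the initialization on a float32 copy of the parameter data, then cast the result back to the parameter's original dtype.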
parent ebf84f07
...@@ -449,17 +449,17 @@ class ViTPreTrainedModel(PreTrainedModel): ...@@ -449,17 +449,17 @@ class ViTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
elif isinstance(module, ViTEmbeddings): elif isinstance(module, ViTEmbeddings):
nn.init.trunc_normal_( module.position_embeddings.data = nn.init.trunc_normal_(
module.position_embeddings, module.position_embeddings.data.to(torch.float32),
mean=0.0, mean=0.0,
std=self.config.initializer_range, std=self.config.initializer_range,
) ).to(module.position_embeddings.dtype)
nn.init.trunc_normal_( module.cls_token.data = nn.init.trunc_normal_(
module.cls_token, module.cls_token.data.to(torch.float32),
mean=0.0, mean=0.0,
std=self.config.initializer_range, std=self.config.initializer_range,
) ).to(module.cls_token.dtype)
def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None: def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None:
if isinstance(module, ViTEncoder): if isinstance(module, ViTEncoder):
......
...@@ -474,17 +474,17 @@ class ViTHybridPreTrainedModel(PreTrainedModel): ...@@ -474,17 +474,17 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
elif isinstance(module, ViTHybridEmbeddings): elif isinstance(module, ViTHybridEmbeddings):
nn.init.trunc_normal_( module.position_embeddings.data = nn.init.trunc_normal_(
module.position_embeddings, module.position_embeddings.data.to(torch.float32),
mean=0.0, mean=0.0,
std=self.config.initializer_range, std=self.config.initializer_range,
) ).to(module.position_embeddings.dtype)
nn.init.trunc_normal_( module.cls_token.data = nn.init.trunc_normal_(
module.cls_token, module.cls_token.data.to(torch.float32),
mean=0.0, mean=0.0,
std=self.config.initializer_range, std=self.config.initializer_range,
) ).to(module.cls_token.dtype)
def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None: def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None:
if isinstance(module, ViTHybridEncoder): if isinstance(module, ViTHybridEncoder):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment