Unverified Commit 72a6bf33 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[Bert, et al] fix early device assignment (#14447)

* fix early device assignment

* more models
parent 83ef8bca
...@@ -219,7 +219,7 @@ class AlbertEmbeddings(nn.Module): ...@@ -219,7 +219,7 @@ class AlbertEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"): if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer( self.register_buffer(
"token_type_ids", "token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False, persistent=False,
) )
......
...@@ -182,7 +182,7 @@ class BertEmbeddings(nn.Module): ...@@ -182,7 +182,7 @@ class BertEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"): if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer( self.register_buffer(
"token_type_ids", "token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False, persistent=False,
) )
......
...@@ -261,7 +261,7 @@ class BigBirdEmbeddings(nn.Module): ...@@ -261,7 +261,7 @@ class BigBirdEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"): if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer( self.register_buffer(
"token_type_ids", "token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False, persistent=False,
) )
# End copy # End copy
......
...@@ -202,7 +202,7 @@ class ConvBertEmbeddings(nn.Module): ...@@ -202,7 +202,7 @@ class ConvBertEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"): if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer( self.register_buffer(
"token_type_ids", "token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False, persistent=False,
) )
......
...@@ -172,7 +172,7 @@ class ElectraEmbeddings(nn.Module): ...@@ -172,7 +172,7 @@ class ElectraEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"): if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer( self.register_buffer(
"token_type_ids", "token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False, persistent=False,
) )
......
...@@ -120,7 +120,7 @@ class FNetEmbeddings(nn.Module): ...@@ -120,7 +120,7 @@ class FNetEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"): if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer( self.register_buffer(
"token_type_ids", "token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False, persistent=False,
) )
......
...@@ -89,7 +89,7 @@ class RobertaEmbeddings(nn.Module): ...@@ -89,7 +89,7 @@ class RobertaEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"): if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer( self.register_buffer(
"token_type_ids", "token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False, persistent=False,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment