"tests/test_modeling_mobilebert.py" did not exist on "86578bb04c9b34f9d8e35cd4fad42a85910dd9e9"
Unverified Commit 72a6bf33 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[Bert, et al] fix early device assignment (#14447)

* fix early device assignment

* more models
parent 83ef8bca
......@@ -219,7 +219,7 @@ class AlbertEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)
......
......@@ -182,7 +182,7 @@ class BertEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)
......
......@@ -261,7 +261,7 @@ class BigBirdEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)
# End copy
......
......@@ -202,7 +202,7 @@ class ConvBertEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)
......
......@@ -172,7 +172,7 @@ class ElectraEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)
......
......@@ -120,7 +120,7 @@ class FNetEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)
......
......@@ -89,7 +89,7 @@ class RobertaEmbeddings(nn.Module):
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
torch.zeros(self.position_ids.size(), dtype=torch.long),
persistent=False,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment