Commit 76928caa authored by Neel Kant
Browse files

Create tensors on cuda rather than copying

parent 2a3b445d
...@@ -52,7 +52,7 @@ class ICTBertModel(MegatronModule): ...@@ -52,7 +52,7 @@ class ICTBertModel(MegatronModule):
def embed_query(self, query_tokens, query_attention_mask): def embed_query(self, query_tokens, query_attention_mask):
"""Embed a batch of tokens using the query model""" """Embed a batch of tokens using the query model"""
if self.use_query_model: if self.use_query_model:
query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda() query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types) query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
return query_ict_logits return query_ict_logits
else: else:
...@@ -61,7 +61,7 @@ class ICTBertModel(MegatronModule): ...@@ -61,7 +61,7 @@ class ICTBertModel(MegatronModule):
def embed_block(self, block_tokens, block_attention_mask): def embed_block(self, block_tokens, block_attention_mask):
"""Embed a batch of tokens using the block model""" """Embed a batch of tokens using the block model"""
if self.use_block_model: if self.use_block_model:
block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda() block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0)
block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types) block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
return block_ict_logits return block_ict_logits
else: else:
......
...@@ -99,8 +99,8 @@ def forward_step(data_iterator, model): ...@@ -99,8 +99,8 @@ def forward_step(data_iterator, model):
global_batch_size = int(batch_size * data_parallel_size) global_batch_size = int(batch_size * data_parallel_size)
all_logits_shape = (int(global_batch_size), int(query_logits.shape[1])) all_logits_shape = (int(global_batch_size), int(query_logits.shape[1]))
all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda() all_query_logits = torch.cuda.FloatTensor(*all_logits_shape).type(query_logits.dtype).fill_(0.0)
all_block_logits = all_query_logits.clone().cuda() all_block_logits = all_query_logits.clone()
# record this processes' data # record this processes' data
all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment