Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
76928caa
Commit
76928caa
authored
Jun 29, 2020
by
Neel Kant
Browse files
Create tensors on cuda rather than copying
parent
2a3b445d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
4 deletions
+4
-4
megatron/model/realm_model.py
megatron/model/realm_model.py
+2
-2
pretrain_bert_ict.py
pretrain_bert_ict.py
+2
-2
No files found.
megatron/model/realm_model.py
View file @
76928caa
...
@@ -52,7 +52,7 @@ class ICTBertModel(MegatronModule):
...
@@ -52,7 +52,7 @@ class ICTBertModel(MegatronModule):
def
embed_query
(
self
,
query_tokens
,
query_attention_mask
):
def
embed_query
(
self
,
query_tokens
,
query_attention_mask
):
"""Embed a batch of tokens using the query model"""
"""Embed a batch of tokens using the query model"""
if
self
.
use_query_model
:
if
self
.
use_query_model
:
query_types
=
torch
.
zeros
(
query_tokens
.
shape
).
type
(
torch
.
int64
).
cuda
(
)
query_types
=
torch
.
cuda
.
LongTensor
(
*
query_tokens
.
shape
).
fill_
(
0
)
query_ict_logits
,
_
=
self
.
query_model
.
forward
(
query_tokens
,
query_attention_mask
,
query_types
)
query_ict_logits
,
_
=
self
.
query_model
.
forward
(
query_tokens
,
query_attention_mask
,
query_types
)
return
query_ict_logits
return
query_ict_logits
else
:
else
:
...
@@ -61,7 +61,7 @@ class ICTBertModel(MegatronModule):
...
@@ -61,7 +61,7 @@ class ICTBertModel(MegatronModule):
def
embed_block
(
self
,
block_tokens
,
block_attention_mask
):
def
embed_block
(
self
,
block_tokens
,
block_attention_mask
):
"""Embed a batch of tokens using the block model"""
"""Embed a batch of tokens using the block model"""
if
self
.
use_block_model
:
if
self
.
use_block_model
:
block_types
=
torch
.
zeros
(
block_tokens
.
shape
).
type
(
torch
.
int64
).
cuda
(
)
block_types
=
torch
.
cuda
.
LongTensor
(
*
block_tokens
.
shape
).
fill_
(
0
)
block_ict_logits
,
_
=
self
.
block_model
.
forward
(
block_tokens
,
block_attention_mask
,
block_types
)
block_ict_logits
,
_
=
self
.
block_model
.
forward
(
block_tokens
,
block_attention_mask
,
block_types
)
return
block_ict_logits
return
block_ict_logits
else
:
else
:
...
...
pretrain_bert_ict.py
View file @
76928caa
...
@@ -99,8 +99,8 @@ def forward_step(data_iterator, model):
...
@@ -99,8 +99,8 @@ def forward_step(data_iterator, model):
global_batch_size
=
int
(
batch_size
*
data_parallel_size
)
global_batch_size
=
int
(
batch_size
*
data_parallel_size
)
all_logits_shape
=
(
int
(
global_batch_size
),
int
(
query_logits
.
shape
[
1
]))
all_logits_shape
=
(
int
(
global_batch_size
),
int
(
query_logits
.
shape
[
1
]))
all_query_logits
=
torch
.
zeros
(
all_logits_shape
).
type
(
query_logits
.
dtype
).
cuda
(
)
all_query_logits
=
torch
.
cuda
.
FloatTensor
(
*
all_logits_shape
).
type
(
query_logits
.
dtype
).
fill_
(
0.0
)
all_block_logits
=
all_query_logits
.
clone
()
.
cuda
()
all_block_logits
=
all_query_logits
.
clone
()
# record this process's data
# record this process's data
all_query_logits
[
args
.
rank
*
batch_size
:(
args
.
rank
+
1
)
*
batch_size
]
=
query_logits
all_query_logits
[
args
.
rank
*
batch_size
:(
args
.
rank
+
1
)
*
batch_size
]
=
query_logits
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment