Commit 275d4e64 authored by Deepak Narayanan

Removal of unneeded changes so that diff is smaller

parent 52a5f2f2
@@ -19,7 +19,6 @@ import torch
 from megatron import get_args
 from megatron import mpu
-from megatron.model.language_model import Embedding
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm
...
...@@ -56,7 +56,8 @@ class Classification(MegatronModule): ...@@ -56,7 +56,8 @@ class Classification(MegatronModule):
attention_mask, next(self.language_model.parameters()).dtype) attention_mask, next(self.language_model.parameters()).dtype)
position_ids = bert_position_ids(input_ids) position_ids = bert_position_ids(input_ids)
_, pooled_output = self.language_model(input_ids, position_ids, _, pooled_output = self.language_model(input_ids,
position_ids,
extended_attention_mask, extended_attention_mask,
tokentype_ids=tokentype_ids) tokentype_ids=tokentype_ids)
......
...@@ -21,7 +21,6 @@ from megatron import get_args ...@@ -21,7 +21,6 @@ from megatron import get_args
from megatron import mpu from megatron import mpu
from megatron.module import MegatronModule from megatron.module import MegatronModule
from .language_model import Embedding
from .language_model import parallel_lm_logits from .language_model import parallel_lm_logits
from .language_model import get_language_model from .language_model import get_language_model
from .utils import init_method_normal from .utils import init_method_normal
......
...@@ -68,7 +68,8 @@ class MultipleChoice(MegatronModule): ...@@ -68,7 +68,8 @@ class MultipleChoice(MegatronModule):
attention_mask, next(self.language_model.parameters()).dtype) attention_mask, next(self.language_model.parameters()).dtype)
position_ids = bert_position_ids(input_ids) position_ids = bert_position_ids(input_ids)
_, pooled_output = self.language_model(input_ids, position_ids, _, pooled_output = self.language_model(input_ids,
position_ids,
extended_attention_mask, extended_attention_mask,
tokentype_ids=tokentype_ids) tokentype_ids=tokentype_ids)
......
...@@ -172,7 +172,8 @@ class IREncoderBertModel(MegatronModule): ...@@ -172,7 +172,8 @@ class IREncoderBertModel(MegatronModule):
position_ids = bert_position_ids(input_ids) position_ids = bert_position_ids(input_ids)
lm_output, pooled_output = self.language_model( lm_output, pooled_output = self.language_model(
input_ids, position_ids, input_ids,
position_ids,
extended_attention_mask, extended_attention_mask,
tokentype_ids=tokentype_ids) tokentype_ids=tokentype_ids)
......
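For reference, the call convention these hunks converge on can be illustrated with a minimal, standalone sketch. This is not Megatron-LM code: StubLanguageModel and the tensor shapes are hypothetical stand-ins, assuming only that the pooled language model's forward takes (input_ids, position_ids, attention_mask, tokentype_ids=None) and returns a (hidden_states, pooled_output) pair, as the hunks above suggest.

# Standalone sketch of the call pattern after this commit (illustrative only,
# not Megatron's actual implementation).
import torch


class StubLanguageModel(torch.nn.Module):
    """Hypothetical stand-in for the pooled language model used by the wrappers."""

    def forward(self, input_ids, position_ids, attention_mask,
                tokentype_ids=None):
        batch, seq = input_ids.shape
        hidden = torch.zeros(batch, seq, 8)   # fake transformer output
        pooled = hidden[:, 0, :]              # fake pooled output for the first token
        return hidden, pooled


if __name__ == "__main__":
    model = StubLanguageModel()
    input_ids = torch.zeros(2, 4, dtype=torch.long)
    position_ids = torch.arange(4).unsqueeze(0).expand(2, -1)
    extended_attention_mask = torch.ones(2, 1, 4, 4)
    # One argument per line, mirroring the reformatted call sites above.
    _, pooled_output = model(input_ids,
                             position_ids,
                             extended_attention_mask,
                             tokentype_ids=None)
    print(pooled_output.shape)  # torch.Size([2, 8])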