Unverified commit 95e70575, authored by Rinat, committed by GitHub

Make vilt, switch_transformers compatible with model parallelism (#22703)

* Update modeling_vilt.py

Make ViLT compatible with model parallelism.

* Update modeling_switch_transformers.py

Make SwitchTransformers compatible with model parallelism.
parent 89087597
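All five hunks below apply the same fix: once a model is sharded across devices, the logits produced by the final head can live on a different GPU than the labels the caller passed in, and computing the loss across devices raises a runtime error. Moving the labels to the logits' device immediately before the loss call resolves this. A minimal standalone sketch of the pattern (illustrative only; the helper name and signature are invented for this example and are not part of the Transformers API):

import torch
from torch.nn import CrossEntropyLoss

def lm_loss(lm_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper showing the device-alignment pattern from this commit."""
    loss_fct = CrossEntropyLoss()  # -100 index = padding token
    # move labels to correct device to enable PP: under pipeline/model
    # parallelism lm_logits may sit on e.g. cuda:1 while labels arrived
    # on cuda:0, and CrossEntropyLoss requires both on the same device
    labels = labels.to(lm_logits.device)
    return loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))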
modeling_switch_transformers.py

@@ -1700,6 +1700,8 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
                 decoder_router_probs = nn.Softmax(dim=-1)(decoder_router_logits)
                 decoder_aux_loss = load_balancing_loss_func(decoder_router_probs, decoder_expert_indexes)
+            # move labels to correct device to enable PP
+            labels = labels.to(lm_logits.device)
             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
         if output_router_logits and labels is not None:
modeling_vilt.py

@@ -1009,6 +1009,8 @@ class ViltForMaskedLM(ViltPreTrainedModel):
         masked_lm_loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()  # -100 index = padding token
+            # move labels to correct device to enable PP
+            labels = labels.to(mlm_logits.device)
             masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.vocab_size), labels.view(-1))
         if not return_dict:
@@ -1155,6 +1157,8 @@ class ViltForQuestionAnswering(ViltPreTrainedModel):
         loss = None
         if labels is not None:
+            # move labels to correct device to enable PP
+            labels = labels.to(logits.device)
             loss = nn.functional.binary_cross_entropy_with_logits(logits, labels) * labels.shape[1]
             # see https://github.com/jnhwkim/ban-vqa/blob/master/train.py#L19
@@ -1395,6 +1399,8 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+            # move labels to correct device to enable PP
+            labels = labels.to(logits.device)
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
         if not return_dict:
@@ -1481,6 +1487,8 @@ class ViltForTokenClassification(ViltPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
+            # move labels to correct device to enable PP
+            labels = labels.to(logits.device)
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
         if not return_dict:
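For context, a hedged usage sketch of what these fixes unblock: sharding one of the affected models across GPUs with device_map="auto" (which requires the accelerate package) and computing a loss inside the forward pass. The checkpoint name is illustrative and the snippet assumes a multi-GPU machine:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Illustrative checkpoint; any SwitchTransformers checkpoint behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = AutoModelForSeq2SeqLM.from_pretrained("google/switch-base-8", device_map="auto")

inputs = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt")
labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids

# Before this commit, this forward pass could fail with a device-mismatch
# error whenever lm_logits landed on a different GPU than the labels.
outputs = model(input_ids=inputs.input_ids, labels=labels)
print(outputs.loss)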