Commit b4297c6a authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'yuya/vit_fix' into 'main'

ViT Backbone Tensor Shape Fix

See merge request ADLR/megatron-lm!479
parents 1a26b291 fabd3e4e
...@@ -234,14 +234,20 @@ class VitBackbone(MegatronModule): ...@@ -234,14 +234,20 @@ class VitBackbone(MegatronModule):
token_embeddings = concatenated_tokens + \ token_embeddings = concatenated_tokens + \
self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]]) self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]])
# [b, s, h] => [s, b, h]
token_embeddings = token_embeddings.transpose(0, 1).contiguous()
hidden_states = self.embedding_dropout(token_embeddings) hidden_states = self.embedding_dropout(token_embeddings)
else: else:
hidden_states = input hidden_states = input
hidden_states = self.transformer(hidden_states, None) hidden_states = self.transformer(hidden_states, None)
if self.single_token_output: if self.post_process:
hidden_states = hidden_states[:,0,:] # [s b h] => [b s h]
if self.single_token_output:
hidden_states = hidden_states[0]
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
return hidden_states return hidden_states
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment