Commit fabd3e4e, authored by Yu Yao, committed by Jared Casper
Browse files

ViT Backbone Tensor Shape Fix

parent 1a26b291
......@@ -234,14 +234,20 @@ class VitBackbone(MegatronModule):
             token_embeddings = concatenated_tokens + \
                     self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]])
+            # [b, s, h] => [s, b, h]
+            token_embeddings = token_embeddings.transpose(0, 1).contiguous()
             hidden_states = self.embedding_dropout(token_embeddings)
         else:
             hidden_states = input
         hidden_states = self.transformer(hidden_states, None)
-        if self.single_token_output:
-            hidden_states = hidden_states[:,0,:]
+        if self.post_process:
+            # [s b h] => [b s h]
+            if self.single_token_output:
+                hidden_states = hidden_states[0]
+            else:
+                hidden_states = hidden_states.transpose(0, 1).contiguous()
         return hidden_states
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment