Unverified Commit 4033ea71 authored by Wang, Yi, committed by GitHub

make build_mpt_alibi_tensor a method of MptModel so that deepspeed could override it to make autoTP work (#25193)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 0fd8d2aa
@@ -413,6 +413,9 @@ class MptModel(MptPreTrainedModel):
     def get_input_embeddings(self):
         return self.wte
 
+    def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None):
+        return build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max, device)
+
     def _prepare_attn_mask(
         self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
     ) -> torch.BoolTensor:
@@ -507,7 +510,7 @@ class MptModel(MptPreTrainedModel):
         else:
             attention_mask = attention_mask.to(hidden_states.device)
 
-        alibi = build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device)
+        alibi = self.build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device)
 
         causal_mask = self._prepare_attn_mask(
             attention_mask,
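For context, the point of the indirection is that the ALiBi builder becomes an overridable attribute of the model instance instead of a hard-wired module-level call. Below is a minimal sketch of how an external framework such as DeepSpeed autoTP could patch the new method to shard the ALiBi bias across tensor-parallel ranks; the sharding scheme and the tp_rank/tp_world_size values are hypothetical illustrations, not DeepSpeed's actual implementation.

import types
from transformers import MptForCausalLM
from transformers.models.mpt.modeling_mpt import build_mpt_alibi_tensor

# Hypothetical tensor-parallel coordinates; a real framework would obtain
# these from its distributed runtime rather than hard-coding them.
tp_rank, tp_world_size = 0, 2

def build_mpt_alibi_tensor_sharded(self, num_heads, sequence_length, alibi_bias_max=8, device=None):
    # Build the full (num_heads, 1, sequence_length) bias with the stock
    # helper, then keep only the heads owned by this tensor-parallel rank
    # (an assumed, simplified sharding scheme for illustration).
    alibi = build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max, device)
    heads_per_rank = num_heads // tp_world_size
    start = tp_rank * heads_per_rank
    return alibi[start : start + heads_per_rank]

model = MptForCausalLM.from_pretrained("mosaicml/mpt-7b")
# Rebind the method on the MptModel instance (model.transformer).
model.transformer.build_mpt_alibi_tensor = types.MethodType(
    build_mpt_alibi_tensor_sharded, model.transformer
)

Because the forward pass now calls self.build_mpt_alibi_tensor(...), the patched method takes effect without any change to the modeling code itself.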