# Copyright 2024 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys from typing import Callable, Optional import torch if sys.version_info >= (3, 11): pass else: pass from transformers.cache_utils import Cache from transformers.modeling_flash_attention_utils import _flash_attention_forward from transformers.models.llama.modeling_llama import apply_rotary_pos_emb from transformers.utils import logging from verl.utils.ulysses import ( gather_heads_scatter_seq, gather_seq_scatter_heads, get_ulysses_sequence_parallel_world_size, validate_ulysses_config, ) logger = logging.get_logger(__name__) def llama_flash_attn_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """ Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. NOTE: This function is used for transformers versions in the range [4.45.0, 4.47.1]. """ output_attentions = False bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) # trade off: repeat first and then all to all # key_states = repeat_kv(key_states, self.num_key_value_groups) # value_states = repeat_kv(value_states, self.num_key_value_groups) ########## AlltoAll for Ulysses ########## ulysses_sp_size = get_ulysses_sequence_parallel_world_size() if ulysses_sp_size > 1: validate_ulysses_config(self.num_heads, ulysses_sp_size) # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim) query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) full_q_len = query_states.size(2) # full seq length if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) dropout_rate = self.attention_dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype logger.warning_once( f"The input hidden states seems to be silently casted in float32, this might be related to " f"the fact you have upcasted embedding or layer norm layers in float32. We will cast back the " f"input in {target_dtype}." ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) attn_output = _flash_attention_forward( query_states, key_states, value_states, attention_mask, full_q_len, position_ids=position_ids, dropout=dropout_rate, sliding_window=getattr(self, "sliding_window", None), use_top_left_mask=self._flash_attn_uses_top_left_mask, is_causal=self.is_causal, **kwargs, ) attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() ########## AlltoAll for Ulysses ########## if ulysses_sp_size > 1: attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value def llama_attn_forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: """ Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0. NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0. """ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.models.llama.modeling_llama import eager_attention_forward bsz, q_len, _ = hidden_states.shape query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) ########## AlltoAll for Ulysses ########## ulysses_sp_size = get_ulysses_sequence_parallel_world_size() if ulysses_sp_size > 1: validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size) query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) full_q_len = query_states.size(2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " "Falling back to eager attention. This warning can be removed using the argument " '`attn_implementation="eager"` when loading the model.' ) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() ########## AlltoAll for Ulysses ########## if ulysses_sp_size > 1: attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights