# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

import torch
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.utils import logging

from verl.utils.ulysses import (
    gather_heads_scatter_seq,
    gather_seq_scatter_heads,
    get_ulysses_sequence_parallel_world_size,
    validate_ulysses_config,
)

logger = logging.get_logger(__name__)


def qwen2_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
):
    """
    Flash-attention forward for Qwen2 with Ulysses sequence parallelism.

    Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.

    NOTE: This function is only tested on transformers versions between 4.45.0 and 4.47.1.

    When the Ulysses sequence-parallel world size is greater than 1, the
    projected q/k/v tensors are exchanged so that each rank sees the gathered
    sequence with a subset of heads; the exchange is reversed on the attention
    output before the output projection.

    Args:
        self: The attention module this function is bound to. The code reads
            ``q_proj``/``k_proj``/``v_proj``/``o_proj``, ``head_dim``,
            ``num_heads``, ``num_key_value_groups``, ``attention_dropout``,
            ``layer_idx``, ``config``, ``is_causal``, ``rotary_emb`` and
            ``_flash_attn_uses_top_left_mask`` from it.
        hidden_states: 3-D input unpacked as ``(bsz, q_len, _)``.
        attention_mask: Forwarded unchanged to ``_flash_attention_forward``.
        position_ids: Used to compute RoPE when ``position_embeddings`` is
            None, and forwarded to ``_flash_attention_forward``.
        past_key_value: Optional cache; updated with the new key/value states.
        output_attentions: Accepted for interface compatibility. This code
            path never computes attention weights, so ``None`` is returned in
            their place regardless of this flag.
        use_cache: Accepted for interface compatibility; not read here.
        cache_position: Passed to the cache update.
        position_embeddings: Optional precomputed ``(cos, sin)`` pair.
            NOTE(review): under Ulysses, RoPE is applied after the sequence
            gather, so these presumably must cover the full sequence — confirm
            against the caller.

    Returns:
        A tuple ``(attn_output, attn_weights, past_key_value)`` where
        ``attn_weights`` is always ``None``.
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

    ########## AlltoAll for Ulysses ##########
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    if ulysses_sp_size > 1:
        validate_ulysses_config(self.num_heads, ulysses_sp_size)

        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)

    # Sequence length seen by attention; differs from q_len once the sequence
    # has been gathered across ranks.
    full_q_len = query_states.size(2)

    if position_embeddings is None:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # In PEFT, the layer norms are usually cast to float32 for training stability,
    # which silently upcasts the hidden states to float32 as well. Cast q/k/v back
    # to the expected lower-precision dtype so everything works as expected.
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to "
            f"the fact you have upcasted embedding or layer norm layers in float32. We will cast back the "
            f"input in {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # Reshape to the expected shape for Flash Attention
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    if (
        self.config.use_sliding_window
        and getattr(self.config, "sliding_window", None) is not None
        and self.layer_idx >= self.config.max_window_layers
    ):
        sliding_window = self.config.sliding_window
    else:
        sliding_window = None

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        full_q_len,
        position_ids=position_ids,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        is_causal=self.is_causal,
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
    )

    # use full_q_len to reshape
    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()

    ########## AlltoAll for Ulysses ##########
    if ulysses_sp_size > 1:
        # Reverse of the exchange above: gather heads, scatter the sequence.
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    # Attention weights are never computed on this path. Bind the name
    # unconditionally: previously it was only assigned when
    # `output_attentions` was False, so `output_attentions=True` raised
    # UnboundLocalError at the return statement.
    attn_weights = None

    return attn_output, attn_weights, past_key_value


def qwen2_attn_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Attention forward for Qwen2 with Ulysses sequence parallelism.

    Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.

    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0.

    The attention backend is selected from ``self.config._attn_implementation``
    (eager by default, otherwise looked up in ``ALL_ATTENTION_FUNCTIONS``).
    When the Ulysses sequence-parallel world size is greater than 1, q/k/v are
    exchanged so each rank sees the gathered sequence with a subset of heads,
    and the exchange is reversed on the attention output.

    Args:
        self: The attention module this function is bound to. The code reads
            ``q_proj``/``k_proj``/``v_proj``/``o_proj``, ``head_dim``,
            ``attention_dropout``, ``scaling``, ``layer_idx`` and ``config``.
        hidden_states: 3-D input unpacked as ``(bsz, q_len, _)``.
        position_embeddings: Precomputed ``(cos, sin)`` pair for RoPE.
            NOTE(review): under Ulysses, RoPE is applied after the sequence
            gather, so these presumably must cover the full sequence — confirm
            against the caller.
        attention_mask: Forwarded unchanged to the attention backend.
        past_key_value: Optional cache; updated with the new key/value states.
        cache_position: Passed to the cache update.
        **kwargs: Forwarded to the attention backend; ``output_attentions`` is
            also inspected to decide on the SDPA-to-eager fallback.

    Returns:
        A tuple ``(attn_output, attn_weights)`` as produced by the selected
        attention backend, with ``attn_output`` passed through ``o_proj``.
    """
    # Imported lazily; presumably these names are unavailable in the older
    # transformers releases served by the other forward in this file.
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    bsz, q_len, _ = hidden_states.shape
    hidden_shape = (bsz, q_len, -1, self.head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    ########## AlltoAll for Ulysses ##########
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    if ulysses_sp_size > 1:
        validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size)

        # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)

    # Sequence length seen by attention; differs from q_len once the sequence
    # has been gathered across ranks.
    full_q_len = query_states.size(2)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    sliding_window = None
    if (
        self.config.use_sliding_window
        and getattr(self.config, "sliding_window", None) is not None
        and self.layer_idx >= self.config.max_window_layers
    ):
        sliding_window = self.config.sliding_window

    from transformers.models.qwen2.modeling_qwen2 import eager_attention_forward

    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
            # Keep the eager backend selected above.
            logger.warning_once(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                "Falling back to eager attention. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        sliding_window=sliding_window,  # main diff with Llama
        **kwargs,
    )

    # Reshape with full_q_len, not q_len: the sequence is still gathered here.
    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()

    ########## AlltoAll for Ulysses ##########
    if ulysses_sp_size > 1:
        # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim)
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights