# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    Qwen2VLCausalLMOutputWithPast,
    Qwen2VLForConditionalGeneration,
)
from transformers.utils import is_flash_attn_greater_or_equal

from verl.utils.ulysses import (
    gather_heads_scatter_seq,
    gather_seq_scatter_heads,
    get_ulysses_sequence_parallel_world_size,
    validate_ulysses_config,
)

try:
    from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
except ImportError:
    # Flash attention is unavailable; disable the varlen path and sliding-window support.
    flash_attn_varlen_func = None
    _flash_supports_window_size = False


def get_rope_index(
    processor,
    input_ids: torch.Tensor,
    image_grid_thw: Optional[torch.Tensor] = None,
    video_grid_thw: Optional[torch.Tensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Get the position ids for Qwen2-VL. They must be computed before the sequence is sharded for
    sequence parallelism. The batch dim has been removed, so input_ids is a 1D tensor representing
    a single example.
    https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546
    """
    spatial_merge_size = processor.image_processor.merge_size
    tokens_per_second = 2
    image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
    video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
    vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")

    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device)  # (3, seqlen)
        image_index, video_index = 0, 0
        input_ids = input_ids[attention_mask == 1]
        image_nums, video_nums = 0, 0
        vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
        vision_tokens = input_ids[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        input_tokens = input_ids.tolist()
        llm_pos_ids_list: list = []
        st = 0
        remain_images, remain_videos = image_nums, video_nums

        for _ in range(image_nums + video_nums):
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1

            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1

            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                second_per_grid_t = 0
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                second_per_grid_t = second_per_grid_ts[video_index] if second_per_grid_ts is not None else 1.0
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t.item(),
                h.item() // spatial_merge_size,
                w.item() // spatial_merge_size,
            )
            text_len = ed - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
            t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
    else:
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
        else:
            position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1)

    return position_ids
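
# Illustrative usage sketch (not part of the upstream API; the processor, prompt and image below are
# hypothetical placeholders): mrope position ids are computed once per unpadded example, before the
# sequence is split for Ulysses sequence parallelism.
#
#   from transformers import AutoProcessor
#   processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
#   inputs = processor(text=[prompt], images=[image], return_tensors="pt")
#   position_ids = get_rope_index(
#       processor,
#       input_ids=inputs["input_ids"][0],          # 1D tensor, batch dim removed
#       image_grid_thw=inputs["image_grid_thw"],   # (num_images, 3): t, h, w patch grid per image
#       attention_mask=inputs["attention_mask"][0],
#   )
#   # position_ids has shape (3, seq_len): temporal, height and width indices for mrope.
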
def prepare_fa2_from_position_ids(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor):
    query = query.view(-1, query.size(-2), query.size(-1))
    key = key.view(-1, key.size(-2), key.size(-1))
    value = value.view(-1, value.size(-2), value.size(-1))
    position_ids = position_ids.flatten()
    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
    cu_seqlens = torch.cat(
        (
            indices_q[position_ids == 0],
            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
        )
    )
    max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope
    return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))


def flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    is_causal: bool = True,
    position_ids: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    deterministic: Optional[bool] = None,
    **kwargs,
):
    """
    Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length)
    """
    causal = is_causal if not use_top_left_mask else is_causal and query_length != 1

    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
    use_sliding_windows = (
        _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
    )
    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

    if is_flash_attn_greater_or_equal("2.4.1"):
        if deterministic is None:
            deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
        flash_kwargs["deterministic"] = deterministic

    if position_ids is not None and query_length != 1 and not (torch.diff(position_ids[0], dim=-1) >= 0).all():
        batch_size = query_states.size(0)
        query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
            query_states, key_states, value_states, position_ids[0]
        )  # remove channel dimension
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=kwargs.pop("dropout", 0.0),
            softmax_scale=kwargs.pop("softmax_scale", None),
            causal=causal,
            **flash_kwargs,
        )
        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
    else:
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            query_length,
            is_causal=is_causal,
            sliding_window=sliding_window,
            use_top_left_mask=use_top_left_mask,
            deterministic=deterministic,
            **kwargs,
        )  # do not pass position_ids to old flash_attention_forward

    return attn_output
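
# Minimal numeric sketch (hypothetical values) of the varlen bookkeeping used above: when multiple
# sequences are packed back-to-back, every position id equal to 0 marks the start of a sub-sequence,
# so prepare_fa2_from_position_ids can read cu_seqlens directly off the flattened position ids.
#
#   position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 0, 1, 2])                 # three packed sequences
#   starts = torch.arange(9, dtype=torch.int32)[position_ids == 0]           # -> [0, 4, 6]
#   cu_seqlens = torch.cat((starts, torch.tensor([9], dtype=torch.int32)))   # -> [0, 4, 6, 9]
#   max_seqlen = cu_seqlens.diff().max()                                     # -> 4
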
def ulysses_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> tuple[torch.Tensor, None, None]:
    from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv

    bsz, q_len, _ = hidden_states.size()  # q_len = seq_length / sp_size
    query_states = self.q_proj(hidden_states)  # (batch_size, seq_length / sp_size, num_heads * head_size)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    if ulysses_sp_size > 1:
        validate_ulysses_config(self.num_heads, ulysses_sp_size)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
        # (batch_size, num_head / sp_size, seq_length, head_size)
        full_q_len = query_states.size(2)  # full_q_len = seq_length
    else:
        full_q_len = q_len

    # Because the input can be padded, the absolute sequence length depends on the max position id.
    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
    )
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # Reshape to the expected shape for Flash Attention
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    if (
        self.config.use_sliding_window
        and getattr(self.config, "sliding_window", None) is not None
        and self.layer_idx >= self.config.max_window_layers
    ):
        sliding_window = self.config.sliding_window
    else:
        sliding_window = None

    attn_output = flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        full_q_len,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        is_causal=self.is_causal,
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
        position_ids=position_ids,  # important: pass position ids
    )  # (batch_size, seq_length, num_head / sp_size, head_size)

    if ulysses_sp_size > 1:
        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, None, None
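
# Shape walkthrough for the Ulysses path above, with assumed (hypothetical) sizes: sp_size = 2,
# bsz = 1, num_heads = 16, head_dim = 128, hidden_size = 2048, local q_len = 4096 (full seq 8192).
#
#   q_proj + view + transpose                 -> (1, 16, 4096, 128)   # full heads, sharded sequence
#   gather_seq_scatter_heads(seq=2, head=1)   -> (1,  8, 8192, 128)   # sharded heads, full sequence
#   transpose(1, 2) for flash attention       -> (1, 8192,  8, 128)
#   flash_attention_forward(...)              -> (1, 8192,  8, 128)
#   gather_heads_scatter_seq(head=2, seq=1)   -> (1, 4096, 16, 128)   # back to the sharded sequence
#   reshape + o_proj                          -> (1, 4096, 2048)
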
@dataclass
class Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast):
    log_probs: Optional[torch.FloatTensor] = None
    entropy: Optional[torch.FloatTensor] = None


def forward_base_model(
    self: Qwen2VLForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> tuple | Qwen2VLCausalLMOutputWithPast:
    r"""
    Copied from Qwen2VL's forward; see
    https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)
        if pixel_values is not None:
            pixel_values = pixel_values.type(self.visual.get_dtype())
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
            n_image_features = image_embeds.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, "
                    f"features {n_image_features}"
                )
            image_mask = (
                (input_ids == self.config.image_token_id)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        if pixel_values_videos is not None:
            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
            n_video_features = video_embeds.shape[0]
            if n_video_tokens != n_video_features:
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, "
                    f"features {n_video_features}"
                )
            video_mask = (
                (input_ids == self.config.video_token_id)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        # calculate RoPE index once per generation in the pre-fill stage only
        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
            position_ids, rope_deltas = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
            self.rope_deltas = rope_deltas
        # then use the prev pre-calculated rope-deltas to get the correct position ids
        else:
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            if cache_position is not None:  # otherwise `deltas` is an int `0`
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
            position_ids = position_ids.add(delta)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )
    return outputs
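
# Minimal sketch (hypothetical values) of the masked_scatter merge used in forward_base_model above:
# visual embeddings are written, in order, into the embedding rows whose token id is the image (or
# video) placeholder, while text rows are left untouched.
#
#   emb  = torch.zeros(4, 3)                                             # 4 tokens, hidden size 3
#   mask = torch.tensor([False, True, True, False]).unsqueeze(-1).expand_as(emb)
#   img  = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])              # 2 image-patch embeddings
#   emb  = emb.masked_scatter(mask, img)                                 # rows 1 and 2 become the image embeddings
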
def forward_with_torch_backend(
    self: Qwen2VLForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **loss_kwargs,
) -> tuple | Qwen2VLCausalLMOutputForPPO:
    from verl.utils.experimental.torch_functional import FusedLinearForPPO

    outputs = forward_base_model(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        rope_deltas=rope_deltas,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    if not return_dict:
        raise NotImplementedError("forward_with_torch_backend has to return_dict")

    # Loss calculations
    if labels is not None:
        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
    elif input_ids is not None:
        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
    else:
        raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.")

    fused_linear_for_ppo = FusedLinearForPPO()
    log_probs, entropy = fused_linear_for_ppo.forward(
        hidden_states=hidden_states,
        vocab_weights=self.lm_head.weight,
        input_ids=rolled_labels,
        temperature=temperature,
    )

    return Qwen2VLCausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=rope_deltas,
    )
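
# Why the labels are rolled in both PPO forwards: the logits at position t are scored against the
# token at position t + 1, so shifting the ids left by one aligns them with the hidden states.
# Sketch with hypothetical ids (the wrapped-around last position is expected to be masked out by
# the caller, e.g. via a response mask):
#
#   input_ids     = torch.tensor([[101, 7, 8, 9]])
#   rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)   # -> [[7, 8, 9, 101]]
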
def forward_with_triton_backend(
    self: Qwen2VLForConditionalGeneration,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    temperature: float = 1.0,
    **loss_kwargs,
) -> tuple | Qwen2VLCausalLMOutputForPPO:
    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy

    outputs = forward_base_model(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        rope_deltas=rope_deltas,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    if not return_dict:
        raise NotImplementedError("forward_with_triton_backend has to return_dict")

    # Loss calculations
    if labels is not None:
        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
    elif input_ids is not None:
        rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
    else:
        raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.")

    log_probs, entropy = linear_cross_entropy(
        hidden_states,
        self.lm_head.weight,
        rolled_labels,
        temperature,
        "none",
    )

    return Qwen2VLCausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=rope_deltas,
    )
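
# Illustrative sketch only (the actual patching entry point in verl is not shown here): one way these
# PPO-oriented forwards could be bound to a loaded model instance. The use of `types.MethodType` and
# the model name are assumptions for the example; requires `import types`.
#
#   model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
#   model.forward = types.MethodType(forward_with_torch_backend, model)  # or forward_with_triton_backend
#   out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True, temperature=1.0)
#   log_probs, entropy = out.log_probs, out.entropy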