# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import sys
import types
from collections import OrderedDict
from typing import List, Optional, Tuple

import safetensors
import torch
import torch.nn.functional as F
import transformers
from packaging import version
from peft import PeftModel
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, trainer
from transformers.modeling_utils import unwrap_model

from swift.utils import get_logger, torchacc_trim_graph, use_torchacc

logger = get_logger()


# DataLoader
def get_bucket_sizes(max_length: int) -> List[int]:
    """Return evenly spaced bucket sizes: [max_length/4, max_length/2, 3*max_length/4, max_length]."""
    return [max_length // 4 * (i + 1) for i in range(4)]


def _get_closest_bucket(bucket_sizes, data_length):
    """Select the smallest bucket in bucket_sizes that is >= data_length.

    Bucketing the padded length keeps the set of tensor shapes small, which is
    required for TorchAcc to avoid excessive graph recompilation.
    """
    closest_length = sys.maxsize
    for b in bucket_sizes:
        if b == data_length or ((b < closest_length) and (b > data_length)):
            closest_length = b

    if closest_length == sys.maxsize:
        bucket_sizes.append(data_length)
        closest_length = data_length

    return closest_length


def pad_and_split_batch(padding_to, input_ids, attention_mask, labels, loss_scale, max_length, tokenizer, rank,
                        world_size):
    if padding_to is None:
        longest_len = input_ids.shape[-1]
        bucket_sizes = get_bucket_sizes(max_length)
        bucket_data_length = _get_closest_bucket(bucket_sizes, longest_len)
        padding_length = bucket_data_length - input_ids.shape[1]
        input_ids = F.pad(input_ids, (0, padding_length), 'constant', tokenizer.pad_token_id)
        attention_mask = F.pad(attention_mask, (0, padding_length), 'constant', 0)
        if loss_scale is not None:
            loss_scale = F.pad(loss_scale, (0, padding_length), 'constant', 0.)
        labels = F.pad(labels, (0, padding_length), 'constant', -100)

    # Manually split the batch across the DP ranks.
    batch_size = input_ids.shape[0] // world_size
    if batch_size > 0:
        start = rank * batch_size
        end = (rank + 1) * batch_size
        input_ids = input_ids[start:end, :]
        attention_mask = attention_mask[start:end, :]
        labels = labels[start:end, :]
        if loss_scale is not None:
            loss_scale = loss_scale[start:end, :]
    return input_ids, attention_mask, labels, loss_scale
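# Illustrative sketch (not used by the trainer): how the bucketing above
# behaves. The values in the comments assume max_length=512 and a longest
# sample of 200 tokens; they follow directly from the two functions above.
def _example_bucketing():
    buckets = get_bucket_sizes(512)  # -> [128, 256, 384, 512]
    padded = _get_closest_bucket(buckets, 200)  # -> 256, the smallest bucket >= 200
    # A batch whose longest sequence is 200 tokens is right-padded to 256, so
    # XLA only ever sees four distinct sequence lengths instead of hundreds.
    return padded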
def ta_train_dataloader(train_dataset, data_collator, sampler, args, batch_size):

    # patch skip_first_batches for the customized dataloader.
    # `ta` is bound by the import below, which runs before trainer code can
    # call this patched function.
    def acc_skip_first_batches(dataloader, num_batches=0):
        from accelerate.data_loader import SkipBatchSampler
        batch_sampler = SkipBatchSampler(dataloader._loader.batch_sampler, skip_batches=num_batches)
        try:
            dataset = dataloader.dataset
        except AttributeError:
            dataset = dataloader._loader.dataset
        dataloader_params = {
            'collate_fn': data_collator,
            'num_workers': args.dataloader_num_workers,
            'pin_memory': args.dataloader_pin_memory,
            'persistent_workers': args.dataloader_persistent_workers,
        }
        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params['batch_sampler'] = batch_sampler
            dataloader_params['worker_init_fn'] = trainer.seed_worker
        return ta.AsyncLoader(DataLoader(dataset, **dataloader_params), args.device)

    trainer.skip_first_batches = acc_skip_first_batches

    # dataloader for TorchAcc.
    import torchacc as ta

    dataloader_params = {
        'batch_size': batch_size,
        'collate_fn': data_collator,
        'num_workers': args.dataloader_num_workers,
        'pin_memory': args.dataloader_pin_memory,
        'persistent_workers': args.dataloader_persistent_workers,
    }

    if not isinstance(train_dataset, torch.utils.data.IterableDataset):
        dataloader_params['sampler'] = sampler
        dataloader_params['drop_last'] = args.dataloader_drop_last
        dataloader_params['worker_init_fn'] = trainer.seed_worker

    return ta.AsyncLoader(DataLoader(train_dataset, **dataloader_params), args.device)


def ta_eval_dataloader(eval_dataset, data_collator, sampler, args):
    import torchacc as ta

    dataloader_params = {
        'batch_size': args.eval_batch_size,
        'collate_fn': data_collator,
        'num_workers': args.dataloader_num_workers,
        'pin_memory': args.dataloader_pin_memory,
        'persistent_workers': args.dataloader_persistent_workers,
    }

    if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
        dataloader_params['sampler'] = sampler
        dataloader_params['drop_last'] = args.dataloader_drop_last

    return ta.AsyncLoader(DataLoader(eval_dataset, **dataloader_params), args.device)


def ta_test_dataloader(test_dataset, data_collator, sampler, args):
    import torchacc as ta

    dataloader_params = {
        'batch_size': args.eval_batch_size,
        'collate_fn': data_collator,
        'num_workers': args.dataloader_num_workers,
        'pin_memory': args.dataloader_pin_memory,
        'persistent_workers': args.dataloader_persistent_workers,
    }

    if not isinstance(test_dataset, torch.utils.data.IterableDataset):
        dataloader_params['sampler'] = sampler
        dataloader_params['drop_last'] = args.dataloader_drop_last

    # We use the same batch_size as for eval.
    return ta.AsyncLoader(DataLoader(test_dataset, **dataloader_params), args.device)
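# Illustrative only: how a Trainer subclass might route its dataloaders
# through the helpers above when TorchAcc is active. `trainer_obj` and the
# sampler choice are hypothetical; the real wiring lives in swift's trainer.
def _example_get_train_dataloader(trainer_obj, sampler=None):
    if use_torchacc():
        return ta_train_dataloader(trainer_obj.train_dataset, trainer_obj.data_collator, sampler, trainer_obj.args,
                                   trainer_obj.args.train_batch_size)
    return trainer_obj.get_train_dataloader()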
# Save/load checkpoint
def consolidate_checkpoint(resume_from_checkpoint, model_name='adapter_model'):
    """Consolidate the sharded TorchAcc FSDP checkpoints into a single model checkpoint."""
    import torch_xla.core.xla_model as xm
    from torch_xla.distributed.fsdp import consolidate_sharded_state_dicts

    if model_name not in ('adapter_model', 'model'):
        logger.error('Only PeftModel and PreTrainedModel are supported.')
        return

    model_dir = os.path.join(resume_from_checkpoint, '0')
    is_pretrained_model = False
    if os.path.exists(os.path.join(model_dir, f'{model_name}.safetensors')):
        use_safetensors = True
    elif os.path.exists(os.path.join(model_dir, f'{model_name}.bin')):
        use_safetensors = False
    elif os.path.exists(os.path.join(model_dir, 'pytorch_model.bin')):
        # PreTrainedModel uses 'pytorch_model.bin' and 'model.safetensors'.
        use_safetensors = False
        is_pretrained_model = True
    else:
        logger.error('Cannot find checkpoint.')
        return

    state_dict_list = []
    if xm.is_master_ordinal(local=False) and use_safetensors:
        from safetensors.torch import load_file
        for rank in range(xm.xrt_world_size()):
            shard_dir = os.path.join(resume_from_checkpoint, f'{rank}')
            filename = os.path.join(shard_dir, f'{model_name}.safetensors')
            state_dict = load_file(filename, device='cpu')
            state_dict = OrderedDict(('_fsdp_wrapped_module.' + k, v) for k, v in state_dict.items())
            state_dict_list.append(state_dict)
        shard_metadata = torch.load(os.path.join(model_dir, 'shard_meta.pth'), map_location='cpu')
    elif xm.is_master_ordinal(local=False):
        for rank in range(xm.xrt_world_size()):
            shard_dir = os.path.join(resume_from_checkpoint, f'{rank}')
            if not is_pretrained_model:
                filename = os.path.join(shard_dir, f'{model_name}.bin')
            else:
                filename = os.path.join(shard_dir, 'pytorch_model.bin')
            state_dict = torch.load(filename, map_location='cpu')
            state_dict = OrderedDict(('_fsdp_wrapped_module.' + k, v) for k, v in state_dict.items())
            state_dict_list.append(state_dict)
        shard_metadata = torch.load(os.path.join(model_dir, 'shard_meta.pth'), map_location='cpu')

    if xm.is_master_ordinal(local=False):
        full_state_dict = consolidate_sharded_state_dicts(state_dict_list, shard_metadata)
        # peft will prepend the "default." prefix automatically, so we remove it
        # here to prevent duplication of the prefix.
        full_state_dict = OrderedDict((k.replace('default.', ''), v) for k, v in full_state_dict.items())
        torch.save(full_state_dict, os.path.join(resume_from_checkpoint, f'{model_name}.bin'))
        if model_name == 'adapter_model':
            config_path = os.path.join(resume_from_checkpoint, 'adapter_config.json')
            old_config_path = os.path.join(model_dir, 'adapter_config.json')
            shutil.copy(old_config_path, config_path)
    xm.rendezvous('ckpt_consolidation')


def ta_save_optimizer_and_scheduler(optimizer, lr_scheduler, output_dir):
    import torch_xla.core.xla_model as xm
    xm.rendezvous('saving_optimizer_states')
    torch.save(optimizer.state_dict(), os.path.join(output_dir, f'optimizer_{xm.get_ordinal()}.pt'))
    torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, f'scheduler_{xm.get_ordinal()}.pt'))
    xm.rendezvous('saving_optimizer_states_done')


def ta_load_optimizer_and_scheduler(optimizer, lr_scheduler, checkpoint, device):
    import torch_xla.core.xla_model as xm
    optimizer_state = torch.load(os.path.join(checkpoint, f'optimizer_{xm.get_ordinal()}.pt'), map_location='cpu')
    lr_scheduler_state = torch.load(os.path.join(checkpoint, f'scheduler_{xm.get_ordinal()}.pt'), map_location='cpu')
    # send_cpu_data_to_device returns the moved copy; assign it back.
    optimizer_state = xm.send_cpu_data_to_device(optimizer_state, device)
    lr_scheduler_state = xm.send_cpu_data_to_device(lr_scheduler_state, device)
    optimizer.load_state_dict(optimizer_state)
    lr_scheduler.load_state_dict(lr_scheduler_state)
    return optimizer, lr_scheduler
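# Illustrative only: the on-disk layout that consolidate_checkpoint consumes,
# assuming 4 FSDP ranks and a LoRA adapter. The directory names are
# hypothetical; the per-rank subdirectories are written by
# save_ta_fsdp_checkpoint below.
#
#   checkpoint-100/
#     0/adapter_model.safetensors   # rank-0 shard (+ shard_meta.pth, adapter_config.json)
#     1/adapter_model.safetensors
#     2/adapter_model.safetensors
#     3/adapter_model.safetensors
#
# After consolidate_checkpoint('checkpoint-100'), a merged
# checkpoint-100/adapter_model.bin and a copied adapter_config.json appear at
# the top level, loadable with PeftModel.from_pretrained.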
def save_ta_ddp_checkpoint(self_model, tokenizer, args, output_dir: Optional[str] = None):
    output_dir = output_dir if output_dir is not None else args.output_dir
    import torch_xla.core.xla_model as xm

    model = self_model

    if xm.is_master_ordinal():
        os.makedirs(output_dir, exist_ok=True)
        torch.save(args, os.path.join(output_dir, 'training_args.bin'))

    xm.mark_step()

    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`.
    supported_classes = (PreTrainedModel, PeftModel)
    if not isinstance(model, supported_classes):
        if isinstance(unwrap_model(model), supported_classes):
            unwrap_model(model).save_pretrained(
                output_dir,
                is_main_process=args.should_save,
                state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
                save_function=xm.save,
                safe_serialization=args.save_safetensors,
            )
        else:
            logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
            state_dict = xm._maybe_convert_to_cpu(model.state_dict())
            if args.save_safetensors:
                safetensors.torch.save_file(state_dict, os.path.join(output_dir, 'model.safetensors'))
            else:
                torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin'))
    else:
        model.save_pretrained(
            output_dir,
            is_main_process=args.should_save,
            save_function=xm.save,
            safe_serialization=args.save_safetensors,
            state_dict=xm._maybe_convert_to_cpu(model.state_dict()))

    if tokenizer is not None and args.should_save:
        tokenizer.save_pretrained(output_dir)


def save_ta_fsdp_checkpoint(self_model, tokenizer, args, output_dir):
    import torch_xla.core.xla_model as xm
    xm.mark_step()

    if xm.is_master_ordinal(local=False):
        os.makedirs(output_dir, exist_ok=True)
        torch.save(args, os.path.join(output_dir, 'training_args.bin'))

    # Unwrap the FSDP wrapper to reach the underlying model.
    model = self_model._get_underlay_model().module.module

    supported_classes = (PreTrainedModel, PeftModel)
    save_safetensors = args.save_safetensors
    # Save a trained model and configuration using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`.
    xm.rendezvous('saving_checkpoint')
    out_dir = os.path.join(output_dir, f'{xm.get_ordinal()}')
    if not isinstance(model, supported_classes):
        if isinstance(unwrap_model(model), supported_classes):
            unwrap_model(model).save_pretrained(
                out_dir,
                state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
                save_function=xm.save,
                safe_serialization=save_safetensors,
            )
        else:
            logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
            state_dict = xm._maybe_convert_to_cpu(model.state_dict())
            if save_safetensors:
                safetensors.torch.save_file(state_dict, os.path.join(out_dir, 'model.safetensors'))
            else:
                torch.save(state_dict, os.path.join(out_dir, 'pytorch_model.bin'))
    else:
        model.save_pretrained(
            out_dir,
            save_function=xm.save,
            safe_serialization=save_safetensors,
            state_dict=xm._maybe_convert_to_cpu(model.state_dict()))

    # Save shard_metadata for later consolidation.
    shard_meta = self_model._get_underlay_model().get_shard_metadata()
    xm.save(shard_meta, os.path.join(out_dir, 'shard_meta.pth'))
    xm.rendezvous('saving_checkpoint_done')

    if tokenizer is not None and args.should_save:
        tokenizer.save_pretrained(
            output_dir, is_main_process=xm.is_master_ordinal(local=False), save_function=xm.save)


def ta_trim_graph():
    if use_torchacc() and torchacc_trim_graph():
        import torchacc as ta
        ta.mark_step()
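# Illustrative only: ta_trim_graph is meant to be called wherever the
# accumulated lazy XLA graph should be cut and executed, e.g. around logging.
# Because of the guard above it is a no-op unless both TorchAcc and graph
# trimming are enabled, so callers can invoke it unconditionally:
#
#     loss = model(**inputs).loss
#     ta_trim_graph()  # flush the pending XLA graph if configured
#     print(f'loss: {loss}')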
# Model patch
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q
            and k have the shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    if position_ids is not None:
        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    else:
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def patch_acc_model(model, args):
    if not args.use_flash_attn:
        logger.warning('TorchAcc currently only supports flash attention; it will be used regardless of '
                       'use_flash_attn.')
    if args.model_type.startswith('qwen1half'):
        model = patch_qwen2_model(model)
    elif args.model_type.startswith('qwen'):
        import torchacc as ta
        model = ta.patch_qwen_model(model)
    elif args.model_type.startswith('baichuan'):
        model = patch_baichuan_model(model)
    elif args.model_type.startswith('llama') or args.model_type.startswith('yi'):
        model = patch_llama_model(model)
    elif args.model_type.startswith('chatglm'):
        model = patch_chatglm_model(model)
    return model
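# Illustrative only: a tiny self-contained check of apply_rotary_pos_emb
# defined above. Shapes follow the unsqueeze_dim=1 convention documented in
# its docstring: cos/sin are [batch, seq_len, head_dim], q/k are
# [batch, heads, seq_len, head_dim]. The frequency construction mirrors the
# standard RoPE recipe and is only for demonstration.
def _example_rope():
    bsz, heads, seq_len, head_dim = 1, 2, 4, 8
    q = torch.randn(bsz, heads, seq_len, head_dim)
    k = torch.randn(bsz, heads, seq_len, head_dim)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.einsum('i,j->ij', t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1).unsqueeze(0)  # [1, seq_len, head_dim]
    q_rot, k_rot = apply_rotary_pos_emb(q, k, emb.cos(), emb.sin())
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    # RoPE is a per-position rotation, so it preserves vector norms.
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
    return q_rot, k_rot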
def patch_llama_model(model):

    def update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_seen_tokens: int,
    ):
        # attention_mask is not supported in TorchAcc.
        return None

    def llama_attn_forward(self,
                           hidden_states: torch.Tensor,
                           attention_mask: Optional[torch.Tensor] = None,
                           position_ids: Optional[torch.Tensor] = None,
                           past_key_value: Optional[Tuple[torch.Tensor]] = None,
                           output_attentions: bool = False,
                           use_cache: bool = False,
                           cache_position: Optional[torch.LongTensor] = None,
                           **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        from torchacc.ops import flash_attn_varlen_xla
        import einops

        bsz, q_len, _ = hidden_states.size()

        query_states = (self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))
        key_states = (
            self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2))
        value_states = (
            self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2))

        kv_seq_len = key_states.shape[-2]
        assert past_key_value is None, 'past_key_value is not supported'

        if version.parse(transformers.__version__) >= version.parse('4.36'):
            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        else:
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        assert not output_attentions, 'output_attentions is not supported'

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # See https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
        # if attention_mask is not None:
        #     value_states = value_states * attention_mask.unsqueeze(1).unsqueeze(-1)
        q = einops.rearrange(query_states, 'b h s ... -> (b s) h ...')
        k = einops.rearrange(key_states, 'b h s ... -> (b s) h ...')
        v = einops.rearrange(value_states, 'b h s ... -> (b s) h ...')
        max_s = q_len
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device)
        output = flash_attn_varlen_xla(
            q, k, v, cu_q_lens, cu_q_lens, max_s, max_s, 0.0, softmax_scale=None, causal=True)
        output = einops.rearrange(output, '(b s) ... -> b s ...', b=bsz)

        return self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)')), None, past_key_value

    for layer in model.model.layers:
        layer.self_attn.forward = types.MethodType(llama_attn_forward, layer.self_attn)
    if version.parse(transformers.__version__) >= version.parse('4.40'):
        model.model._update_causal_mask = types.MethodType(update_causal_mask, model.model)

    return model
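# Illustrative only: how the cumulative-length tensor used by the patched
# attentions packs a padded batch for the varlen flash-attention kernels.
# With bsz=2 and q_len=3, the (b s) flattening produces 6 rows and
# cu_q_lens = [0, 3, 6]: consecutive entries delimit each sequence inside the
# packed buffer, so the kernel never attends across batch boundaries even
# though the batch dimension is gone.
def _example_cu_seqlens(bsz: int = 2, q_len: int = 3) -> torch.Tensor:
    return torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)  # -> [0, 3, 6]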
def patch_chatglm_model(model):

    def chatglm_apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
        # x: [sq, b, np, hn]
        sq, _, np, _ = x.size(0), x.size(1), x.size(2), x.size(3)
        rot_dim = rope_cache.shape[-2] * 2
        x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
        # truncate to support variable sizes
        rope_cache = rope_cache[:sq]
        xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
        rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
        x_out2 = torch.stack(
            [
                xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
                xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )
        x_out2 = x_out2.flatten(3)
        return torch.cat((x_out2, x_pass), dim=-1)

    def chatglm_attn_forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
                             **kwargs):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(query_layer.size()[:-1]
                                           + (self.num_attention_heads_per_partition,
                                              self.hidden_size_per_attention_head))
            key_layer = key_layer.view(key_layer.size()[:-1]
                                       + (self.num_multi_query_groups_per_partition,
                                          self.hidden_size_per_attention_head))
            value_layer = value_layer.view(value_layer.size()[:-1]
                                           + (self.num_multi_query_groups_per_partition,
                                              self.hidden_size_per_attention_head))
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition,
                                                            3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            # Equal three-way split along the last dim, equivalent to
            # modeling_chatglm's split_tensor_along_last_dim (which is not
            # importable from this module).
            (query_layer, key_layer, value_layer) = mixed_x_layer.chunk(3, dim=-1)

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = chatglm_apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = chatglm_apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # adjust key and value for inference
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        if self.multi_query_attention:
            key_layer = key_layer.unsqueeze(-2)
            key_layer = key_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
            key_layer = key_layer.contiguous().view(key_layer.size()[:2]
                                                    + (self.num_attention_heads_per_partition,
                                                       self.hidden_size_per_attention_head))
            value_layer = value_layer.unsqueeze(-2)
            value_layer = value_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
            value_layer = value_layer.contiguous().view(value_layer.size()[:2]
                                                        + (self.num_attention_heads_per_partition,
                                                           self.hidden_size_per_attention_head))

        # ==================================
        # core attention computation
        # ==================================
        from torchacc.ops import flash_attn_varlen_qkvpacked_xla
        import einops

        query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
        bsz, _, q_len, _ = query_layer.size()
        qkv = torch.stack([query_layer, key_layer, value_layer], dim=2)
        qkv = qkv.transpose(1, 3)
        qkv = einops.rearrange(qkv, 'b s ... -> (b s) ...')
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
        context_layer = flash_attn_varlen_qkvpacked_xla(
            qkv, cu_q_lens, q_len, dropout_p=0.0, softmax_scale=None, causal=True)
        context_layer = einops.rearrange(context_layer, '(b s) ... -> b s ...', b=bsz)
        context_layer = context_layer.permute(1, 0, 2, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (self.core_attention.hidden_size_per_partition, )
        context_layer = context_layer.reshape(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================
        output = self.dense(context_layer)

        return output, kv_cache

    def torchacc_swiglu(x):
        x = torch.chunk(x, 2, dim=-1)
        return F.silu(x[0]).to(x[0].dtype) * x[1]

    # patch attention
    for layer in model.transformer.encoder.layers:
        layer.self_attention.forward = types.MethodType(chatglm_attn_forward, layer.self_attention)
        layer.mlp.activation_func = torchacc_swiglu

    return model
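# Illustrative only: the torchacc_swiglu patch above replaces ChatGLM's
# activation with an equivalent chunk-based SwiGLU. For an input of shape
# [..., 2d] it returns silu(x[..., :d]) * x[..., d:], halving the last dim:
def _example_swiglu():
    x = torch.randn(2, 8)
    out = F.silu(x[..., :4]) * x[..., 4:]  # same math as torchacc_swiglu(x)
    assert out.shape == (2, 4)
    return out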
def patch_baichuan_model(model):

    def baichuan_attn_forward(self,
                              hidden_states: torch.Tensor,
                              attention_mask: Optional[torch.Tensor] = None,
                              past_key_value: Optional[Tuple[torch.Tensor]] = None,
                              output_attentions: bool = False,
                              use_cache: bool = False,
                              **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        import einops

        bsz, q_len, _ = hidden_states.size()

        proj = self.W_pack(hidden_states)
        proj = (proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2))
        query_states = (proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))
        key_states = (proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))
        value_states = (proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        from torchacc.ops import flash_attn_varlen_xla

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        q, k, v = [einops.rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device)
        output = flash_attn_varlen_xla(
            q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, 0.0, softmax_scale=None, causal=True)
        output = einops.rearrange(output, '(b s) ... -> b s ...', b=bsz)
        output = self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)'))
        return output, None, past_key_value

    for layer in model.base_model.layers:
        layer.self_attn.forward = types.MethodType(baichuan_attn_forward, layer.self_attn)

    return model
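# Illustrative only: what the W_pack reshaping above does. Baichuan fuses the
# q/k/v projections into a single linear layer; unflatten splits the fused
# output back into three [bsz, q_len, hidden_size] tensors:
def _example_unpack_wpack(bsz: int = 2, q_len: int = 3, hidden_size: int = 8) -> torch.Tensor:
    proj = torch.randn(bsz, q_len, 3 * hidden_size)
    proj = proj.unflatten(-1, (3, hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
    assert proj.shape == (3, bsz, q_len, hidden_size)  # proj[0]/proj[1]/proj[2] are q/k/v
    return proj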
def patch_qwen2_model(model):

    def qwen2_attn_forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
                    'for auto-regressive decoding with k/v caching, please make sure to initialize the attention '
                    'class with a layer index.')
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        # rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        rotary_seq_len = kv_seq_len + 1
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, the layer norms are usually cast to float32 for training
        # stability, so the input hidden states may have been silently cast to
        # float32 as well. Cast them back to the expected half-precision dtype.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, '_pre_quantization_dtype'):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for flash attention.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        from torchacc.ops import flash_attn_varlen_xla
        import einops

        q, k, v = [einops.rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device)
        attn_output = flash_attn_varlen_xla(
            q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, dropout_rate, softmax_scale=None, causal=True)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # Flash attention does not return the attention weights.
        attn_weights = None

        return attn_output, attn_weights, past_key_value
    def qwen2_forward(self,
                      input_ids: torch.LongTensor = None,
                      attention_mask: Optional[torch.Tensor] = None,
                      position_ids: Optional[torch.LongTensor] = None,
                      past_key_values: Optional[List[torch.FloatTensor]] = None,
                      inputs_embeds: Optional[torch.FloatTensor] = None,
                      use_cache: Optional[bool] = None,
                      output_attentions: Optional[bool] = None,
                      output_hidden_states: Optional[bool] = None,
                      return_dict: Optional[bool] = None,
                      **kwargs):
        from transformers.cache_utils import Cache, DynamicCache

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time')
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds')

        # use_cache is incompatible with gradient checkpointing.
        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        past_key_values_length = 0

        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

        from transformers.modeling_outputs import BaseModelOutputWithPast
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
    for layer in model.model.layers:
        layer.self_attn.forward = types.MethodType(qwen2_attn_forward, layer.self_attn)
    model.model.forward = types.MethodType(qwen2_forward, model.model)

    return model
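# Illustrative only: a minimal sketch of how the patches above are applied,
# assuming `args` carries swift's usual `model_type`/`use_flash_attn` fields
# consumed by patch_acc_model. The model name is a placeholder; the dispatch
# on a 'qwen1half' prefix routes to patch_qwen2_model, which rewrites each
# attention layer in place:
#
#     model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen1.5-7B')
#     model = patch_acc_model(model, args)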