"scripts/vscode:/vscode.git/clone" did not exist on "50949b9a58de65a05046b144b57f439bb877fbf1"
Unverified commit a41cf88e, authored by github-actions[bot] and committed by GitHub

[format] applied code formatting on changed files in pull request 4908 (#4918)


Co-authored-by: github-actions <github-actions@github.com>
parent 4f68b3f1
@@ -6,25 +6,20 @@ from typing import Optional, Tuple
 import torch
 import torch.nn.functional as F
-from einops import rearrange
-from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
-from flash_attn.ops.rms_norm import rms_norm
 from transformers.models.llama.modeling_llama import (
-    LlamaRMSNorm,
     LlamaAttention,
-    LlamaModel,
     LlamaForCausalLM,
+    LlamaModel,
+    LlamaRMSNorm,
     apply_rotary_pos_emb,
     repeat_kv,
 )
 from colossalai.logging import get_dist_logger
+from einops import rearrange
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import (
+    flash_attn_func,
+    flash_attn_varlen_kvpacked_func,
+)
+from flash_attn.ops.rms_norm import rms_norm
 
 logger = get_dist_logger()
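The hunk above is pure import hygiene: the einops and flash_attn imports move below the colossalai import, the long flash_attn_interface import is wrapped onto multiple lines, and the names inside the transformers import are alphabetized with classes before functions. A minimal sketch of that kind of reordering through isort's Python API; the diff does not name the formatter, so isort here is an assumption:

import isort  # assumption: an isort-compatible sorter produced this hunk

messy = """\
from flash_attn.ops.rms_norm import rms_norm
from transformers.models.llama.modeling_llama import (
    LlamaRMSNorm,
    LlamaAttention,
    LlamaModel,
    LlamaForCausalLM,
)
from einops import rearrange
"""

# isort.code() returns the source with the import lines regrouped and the
# names inside parenthesized imports reordered (classes before functions),
# the same normalization visible in the hunk above.
print(isort.code(messy))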
@@ -65,7 +60,7 @@ def attention_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    **kwargs
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """
     Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
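The second hunk only adds a trailing comma after **kwargs, but it marks the attention_forward entry point that the relocated flash-attention imports serve. A minimal, self-contained sketch of the flash_attn_func call pattern such a forward builds on; the shapes follow flash-attn's documented (batch, seqlen, heads, head_dim) layout, and the concrete sizes and flags are illustrative, not taken from this patch:

import torch
from flash_attn.flash_attn_interface import flash_attn_func

batch, seqlen, n_heads, head_dim = 2, 128, 8, 64

# flash-attn kernels operate on fp16/bf16 tensors on a CUDA device.
q = torch.randn(batch, seqlen, n_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# causal=True applies the autoregressive mask used by decoder-only LLaMA;
# the output keeps the (batch, seqlen, n_heads, head_dim) layout of q.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
assert out.shape == (batch, seqlen, n_heads, head_dim)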