Unverified Commit 954df466, authored by Yang Yong (雍洋), committed by GitHub

Support SVG Attention (#374)

parent 51e102fe
{
"infer_steps": 40,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "svg_attn",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"sample_guide_scale": 5,
"sample_shift": 3,
"enable_cfg": true,
"cpu_offload": false
}
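This config switches only the first self-attention to SVG sparse attention ("self_attn_1_type": "svg_attn"); both cross-attention paths stay on FlashAttention 3. A minimal sketch of how these keys resolve to attention classes through the registry (the config dict below is illustrative, not a file from this commit; key names are taken from the diff):

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

config = {"self_attn_1_type": "svg_attn", "cross_attn_1_type": "flash_attn3"}
self_attn_cls = ATTN_WEIGHT_REGISTER[config["self_attn_1_type"]]    # -> SvgAttnWeight
cross_attn_cls = ATTN_WEIGHT_REGISTER[config["cross_attn_1_type"]]  # -> FlashAttn3Weight

The package exports are also made explicit so SvgAttnWeight is importable alongside the other backends: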
-from .flash_attn import *
-from .radial_attn import *
-from .ring_attn import *
-from .sage_attn import *
-from .torch_sdpa import *
-from .ulysses_attn import *
+from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
+from .radial_attn import RadialAttnWeight
+from .ring_attn import RingAttnWeight
+from .sage_attn import SageAttn2Weight
+from .svg_attn import SvgAttnWeight
+from .torch_sdpa import TorchSDPAWeight
+from .ulysses_attn import UlyssesAttnWeight
import math
from functools import lru_cache
from math import ceil
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from loguru import logger
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
@triton.jit
def wan_hidden_states_placement_kernel(
hidden_states_ptr, # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
hidden_states_out_ptr, # [cfg, num_heads, seq_len, head_dim]
best_mask_idx_ptr, # [cfg, num_heads]
hidden_states_stride_b,
hidden_states_stride_h,
hidden_states_stride_s,
hidden_states_stride_d,
mask_idx_stride_b,
mask_idx_stride_h,
seq_len: tl.constexpr,
head_dim: tl.constexpr,
context_length: tl.constexpr,
num_frame: tl.constexpr,
frame_size: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# Copy hidden_states to output
# range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
cfg = tl.program_id(0)
head = tl.program_id(1)
block_id = tl.program_id(2)
start_id = block_id * BLOCK_SIZE
end_id = start_id + BLOCK_SIZE
end_id = tl.where(end_id > seq_len, seq_len, end_id)
# Load best mask idx (0 is spatial, 1 is temporal)
is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)
offset_token = tl.arange(0, BLOCK_SIZE) + start_id
offset_mask = offset_token < seq_len
offset_d = tl.arange(0, head_dim)
if is_temporal:
# Inverse of wan_sparse_head_placement_kernel: a token at position
# t = patch_id * num_frame + frame_id in the patch-major layout returns to
# frame_id * frame_size + patch_id; context tokens keep their position.
patch_id = offset_token // num_frame
frame_id = offset_token - patch_id * num_frame
offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, frame_id * frame_size + patch_id)
offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
offset_hidden_states = hidden_states_ptr + offset_load
offset_store = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_store_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
offset_hidden_states_out = hidden_states_out_ptr + offset_store
# Maybe tune the pipeline here
hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None])
tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None])
else:
offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
offset_hidden_states = hidden_states_ptr + offset_load
offset_store = offset_load
offset_hidden_states_out = hidden_states_out_ptr + offset_store
# Maybe tune the pipeline here
hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None])
tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None])
def wan_hidden_states_placement(hidden_states, hidden_states_out, best_mask_idx, context_length, num_frame, frame_size):
cfg, num_heads, seq_len, head_dim = hidden_states.shape
BLOCK_SIZE = 128
assert seq_len == context_length + num_frame * frame_size
grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
wan_hidden_states_placement_kernel[grid](
hidden_states,
hidden_states_out,
best_mask_idx,
hidden_states.stride(0),
hidden_states.stride(1),
hidden_states.stride(2),
hidden_states.stride(3),
best_mask_idx.stride(0),
best_mask_idx.stride(1),
seq_len,
head_dim,
context_length,
num_frame,
frame_size,
BLOCK_SIZE,
)
return hidden_states_out
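A plain-PyTorch reference for the placement above (a sanity-check sketch, not part of the commit; it assumes the same [cfg, heads, seq, dim] layout, with context tokens in the last context_length positions):

def hidden_states_placement_ref(hidden_states, best_mask_idx, num_frame, frame_size):
    # Invert the patch-major layout: position t = patch_id * num_frame + frame_id
    # goes back to frame_id * frame_size + patch_id for temporal heads.
    out = hidden_states.clone()
    video_len = num_frame * frame_size
    t = torch.arange(video_len, device=hidden_states.device)
    dst = (t % num_frame) * frame_size + (t // num_frame)
    for b in range(hidden_states.shape[0]):
        for h in range(hidden_states.shape[1]):
            if best_mask_idx[b, h] == 1:  # 1 = temporal head
                out[b, h, dst] = hidden_states[b, h, :video_len]
    return out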
@triton.jit
def wan_sparse_head_placement_kernel(
query_ptr,
key_ptr,
value_ptr, # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
query_out_ptr,
key_out_ptr,
value_out_ptr, # [cfg, num_heads, seq_len, head_dim]
best_mask_idx_ptr, # [cfg, num_heads]
query_stride_b,
query_stride_h,
query_stride_s,
query_stride_d,
mask_idx_stride_b,
mask_idx_stride_h,
seq_len: tl.constexpr,
head_dim: tl.constexpr,
context_length: tl.constexpr,
num_frame: tl.constexpr,
frame_size: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# Copy query, key, value to output
# range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
cfg = tl.program_id(0)
head = tl.program_id(1)
block_id = tl.program_id(2)
start_id = block_id * BLOCK_SIZE
end_id = start_id + BLOCK_SIZE
end_id = tl.where(end_id > seq_len, seq_len, end_id)
# Load best mask idx (0 is spatial, 1 is temporal)
is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)
offset_token = tl.arange(0, BLOCK_SIZE) + start_id
offset_mask = offset_token < seq_len
offset_d = tl.arange(0, head_dim)
if is_temporal:
# Regroup video tokens from frame-major (t = frame_id * frame_size + patch_id)
# to patch-major (patch_id * num_frame + frame_id), so tokens of one spatial
# patch across frames become contiguous; context tokens keep their position.
frame_id = offset_token // frame_size
patch_id = offset_token - frame_id * frame_size
offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, patch_id * num_frame + frame_id)
offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
offset_query = query_ptr + offset_load
offset_key = key_ptr + offset_load
offset_value = value_ptr + offset_load
offset_store = (cfg * query_stride_b + head * query_stride_h + offset_store_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
offset_query_out = query_out_ptr + offset_store
offset_key_out = key_out_ptr + offset_store
offset_value_out = value_out_ptr + offset_store
# Maybe tune the pipeline here
query = tl.load(offset_query, mask=offset_mask[:, None])
tl.store(offset_query_out, query, mask=offset_mask[:, None])
key = tl.load(offset_key, mask=offset_mask[:, None])
tl.store(offset_key_out, key, mask=offset_mask[:, None])
value = tl.load(offset_value, mask=offset_mask[:, None])
tl.store(offset_value_out, value, mask=offset_mask[:, None])
else:
offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
offset_query = query_ptr + offset_load
offset_key = key_ptr + offset_load
offset_value = value_ptr + offset_load
offset_store = offset_load
offset_query_out = query_out_ptr + offset_store
offset_key_out = key_out_ptr + offset_store
offset_value_out = value_out_ptr + offset_store
# Maybe tune the pipeline here
query = tl.load(offset_query, mask=offset_mask[:, None])
tl.store(offset_query_out, query, mask=offset_mask[:, None])
key = tl.load(offset_key, mask=offset_mask[:, None])
tl.store(offset_key_out, key, mask=offset_mask[:, None])
value = tl.load(offset_value, mask=offset_mask[:, None])
tl.store(offset_value_out, value, mask=offset_mask[:, None])
def wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
cfg, num_heads, seq_len, head_dim = query.shape
BLOCK_SIZE = 128
assert seq_len == context_length + num_frame * frame_size
grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
wan_sparse_head_placement_kernel[grid](
query,
key,
value,
query_out,
key_out,
value_out,
best_mask_idx,
query.stride(0),
query.stride(1),
query.stride(2),
query.stride(3),
best_mask_idx.stride(0),
best_mask_idx.stride(1),
seq_len,
head_dim,
context_length,
num_frame,
frame_size,
BLOCK_SIZE,
)
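The two placement kernels are inverses on the video tokens, which a quick round trip verifies (illustrative; tiny shapes, context_length = 0, CUDA required):

cfg, heads, nf, fs, dim = 1, 2, 4, 6, 8
x = torch.randn(cfg, heads, nf * fs, dim, device="cuda")
idx = torch.randint(0, 2, (cfg, heads), device="cuda", dtype=torch.int32)
qo, ko, vo = torch.zeros_like(x), torch.zeros_like(x), torch.zeros_like(x)
wan_sparse_head_placement(x, x, x, qo, ko, vo, idx, 0, nf, fs)
back = torch.zeros_like(x)
wan_hidden_states_placement(qo, back, idx, 0, nf, fs)
assert torch.equal(back, x)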
def generate_temporal_head_mask_mod(context_length: int = 226, prompt_length: int = 226, num_frames: int = 13, token_per_frame: int = 1350, mul: int = 2):
def round_to_multiple(idx):
return ceil(idx / 128) * 128
def temporal_mask_mod(b, h, q_idx, kv_idx):
two_frame = round_to_multiple(mul * token_per_frame)
temporal_head_mask = torch.abs(q_idx - kv_idx) <= two_frame
first_frame_mask = kv_idx < token_per_frame
video_mask = first_frame_mask | temporal_head_mask
return video_mask
return temporal_mask_mod
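A worked instance of the window this defines (token_per_frame = 1560 is an assumed value for 480x832 latents; only the arithmetic is the point):

mul, token_per_frame = 2, 1560
two_frame = ceil(mul * token_per_frame / 128) * 128  # ceil(3120 / 128) * 128 = 3200
# a query attends kv_idx iff |q_idx - kv_idx| <= 3200, or kv_idx < 1560 (first-frame sink)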
@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False):
block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile)
return block_mask
def prepare_flexattention(cfg_size, num_head, head_dim, dtype, device, context_length, prompt_length, num_frame, frame_size, diag_width=1, multiplier=2):
assert diag_width == multiplier, f"diag_width ({diag_width}) must equal multiplier ({multiplier})"
seq_len = context_length + num_frame * frame_size
mask_mod = generate_temporal_head_mask_mod(context_length, prompt_length, num_frame, frame_size, mul=multiplier)
block_mask = create_block_mask_cached(mask_mod, None, None, seq_len, seq_len, device=device, _compile=True)
return block_mask
def sparsity_to_width(sparsity, context_length, num_frame, frame_size):
seq_len = context_length + num_frame * frame_size
total_elements = seq_len**2
sparsity = (sparsity * total_elements - 2 * seq_len * context_length) / total_elements
width = seq_len * (1 - math.sqrt(1 - sparsity))
width_frame = width / frame_size
return width_frame
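The closed form inverts the density of a diagonal band: a band |i - j| <= w over an s x s matrix keeps s^2 - (s - w)^2 entries, so keep = 1 - (1 - w/s)^2 and w = s * (1 - sqrt(1 - keep)). Note that the `sparsity` argument is therefore treated as the fraction of entries kept, after discounting the two context strips. A quick numeric check (illustrative):

s, keep = 4096, 0.25
w = int(s * (1 - math.sqrt(1 - keep)))
kept = sum(min(i + w, s - 1) - max(i - w, 0) + 1 for i in range(s))
print(kept / s**2)  # ~0.25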
def get_attention_mask(mask_name, sample_mse_max_row, context_length, num_frame, frame_size):
attention_mask = torch.zeros((context_length + num_frame * frame_size, context_length + num_frame * frame_size), device="cpu")
# TODO: fix hard coded mask
if mask_name == "spatial":
pixel_attn_mask = torch.zeros_like(attention_mask, dtype=torch.bool, device="cpu")
pixel_attn_mask[:, :frame_size] = 1 # First Frame Sink
block_size, block_thres = 128, frame_size * 2
num_block = math.ceil(num_frame * frame_size / block_size)
for i in range(num_block):
for j in range(num_block):
if abs(i - j) < block_thres // block_size:
pixel_attn_mask[i * block_size : (i + 1) * block_size, j * block_size : (j + 1) * block_size] = 1
attention_mask = pixel_attn_mask
else: # "temporal": the same band, re-expressed in the original frame-major token order so it selects the same spatial patch across frames
pixel_attn_mask = torch.zeros_like(attention_mask, dtype=torch.bool, device="cpu")
pixel_attn_mask[:, :frame_size] = 1 # First Frame Sink
block_size, block_thres = 128, frame_size * 2
num_block = math.ceil(num_frame * frame_size / block_size)
for i in range(num_block):
for j in range(num_block):
if abs(i - j) < block_thres // block_size:
pixel_attn_mask[i * block_size : (i + 1) * block_size, j * block_size : (j + 1) * block_size] = 1
pixel_attn_mask = pixel_attn_mask.reshape(frame_size, num_frame, frame_size, num_frame).permute(1, 0, 3, 2).reshape(frame_size * num_frame, frame_size * num_frame)
attention_mask = pixel_attn_mask
attention_mask = attention_mask[:sample_mse_max_row].cuda()
return attention_mask
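A quick look at the two candidate masks this yields (illustrative values; a CUDA device is required since the mask is moved to GPU):

spatial = get_attention_mask("spatial", 10000, context_length=0, num_frame=21, frame_size=1560)
temporal = get_attention_mask("temporal", 10000, context_length=0, num_frame=21, frame_size=1560)
print(spatial.shape, temporal.shape)  # both torch.Size([10000, 32760]); rows truncated to sample_mse_max_row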
@ATTN_WEIGHT_REGISTER("svg_attn")
class SvgAttnWeight(AttnWeightTemplate):
head_num = None
head_dim = None
sample_mse_max_row = None
num_sampled_rows = None
context_length = None
num_frame = None
frame_size = None
sparsity = None
mask_name_list = ["spatial", "temporal"]
attention_masks = None
block_mask = None
@classmethod
def prepare(cls, head_num, head_dim, sample_mse_max_row, num_sampled_rows, context_length, sparsity):
cls.head_num = head_num
cls.head_dim = head_dim
cls.sample_mse_max_row = sample_mse_max_row
cls.num_sampled_rows = num_sampled_rows
cls.context_length = context_length
cls.sparsity = sparsity
# Raise dynamo cache limits so the many flex_attention recompilations are not evicted
torch._dynamo.config.cache_size_limit = 192 * 3
torch._dynamo.config.accumulated_cache_size_limit = 192 * 3
logger.info(
f"SvgAttnWeight Prepare: head_num={head_num}, head_dim={head_dim}, sample_mse_max_row={sample_mse_max_row}, num_sampled_rows={num_sampled_rows}, context_length={context_length}, sparsity={sparsity}"
)
def __init__(self):
self.config = {}
self.sparse_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")
@classmethod
def prepare_mask(cls, num_frame, frame_size):
# Use class attributes so updates affect all instances of this class
if num_frame == cls.num_frame and frame_size == cls.frame_size:
return
cls.num_frame = num_frame
cls.frame_size = frame_size
cls.attention_masks = [get_attention_mask(mask_name, cls.sample_mse_max_row, cls.context_length, num_frame, frame_size) for mask_name in cls.mask_name_list]
multiplier = diag_width = sparsity_to_width(cls.sparsity, cls.context_length, num_frame, frame_size)
cls.block_mask = prepare_flexattention(
1, cls.head_num, cls.head_dim, torch.bfloat16, "cuda", cls.context_length, cls.context_length, num_frame, frame_size, diag_width=diag_width, multiplier=multiplier
)
logger.info(f"SvgAttnWeight Update: num_frame={num_frame}, frame_size={frame_size}")
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
q = q.unsqueeze(0).transpose(1, 2)
k = k.unsqueeze(0).transpose(1, 2)
v = v.unsqueeze(0).transpose(1, 2)
bs, num_heads, seq_len, dim = q.size()
num_frame = 21  # hard-coded latent frame count (e.g. 81 output frames -> 21 latent frames for Wan)
self.prepare_mask(num_frame=num_frame, frame_size=seq_len // num_frame)
sampled_mses = self.sample_mse(q, k, v)
best_mask_idx = torch.argmin(sampled_mses, dim=0)
output_hidden_states = torch.zeros_like(q)
query_out, key_out, value_out = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
query_out, key_out, value_out = self.fast_sparse_head_placement(q, k, v, query_out, key_out, value_out, best_mask_idx, self.context_length, self.num_frame, self.frame_size)
hidden_states = self.sparse_attention(query_out, key_out, value_out, block_mask=self.block_mask)  # pass the precomputed block mask; without it flex_attention runs dense
wan_hidden_states_placement(hidden_states, output_hidden_states, best_mask_idx, self.context_length, self.num_frame, self.frame_size)
return output_hidden_states.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)
def fast_sparse_head_placement(self, query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size)
return query_out, key_out, value_out
def sample_mse(self, query, key, value):
cfg, num_heads, seq_len, dim = query.size()
num_sampled_rows = min(self.num_sampled_rows, seq_len)
sampled_rows = torch.randint(low=0, high=self.sample_mse_max_row, size=(num_sampled_rows,))
sampled_q = query[:, :, sampled_rows, :]
sampled_qk_scores = torch.matmul(sampled_q, key.transpose(-2, -1)) / (dim**0.5)
sampled_attn_weights = F.softmax(sampled_qk_scores, dim=-1)
sampled_golden_hidden_states = torch.matmul(sampled_attn_weights, value)  # (cfg, num_heads, num_sampled_rows, dim)
sampled_mses = torch.zeros(len(self.attention_masks), cfg, num_heads, device=query.device, dtype=query.dtype)
# Two candidate patterns per head: spatial (banded / tri-diagonal) and temporal (striped)
for mask_idx, attn_mask in enumerate(self.attention_masks):
sampled_attention_mask = attn_mask[sampled_rows, :]
sampled_attention_scores = sampled_qk_scores.masked_fill(sampled_attention_mask == 0, float("-inf"))
sampled_attn_weights = F.softmax(sampled_attention_scores, dim=-1)
sampled_hidden_states = torch.matmul(sampled_attn_weights, value)
mse = torch.mean((sampled_hidden_states - sampled_golden_hidden_states) ** 2, dim=(2, 3))
sampled_mses[mask_idx] = mse
return sampled_mses
if __name__ == "__main__":
q = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda()
k = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda()
v = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda()
SvgAttnWeight.prepare(head_num=40, head_dim=128, sample_mse_max_row=10000, num_sampled_rows=64, context_length=0, sparsity=0.25)
svg_attn = SvgAttnWeight()
print("SvgAttnWeight initialized.")
out = svg_attn.apply(q, k, v)
print(f"out: {out.shape}, {out.dtype}, {out.device}")
@@ -29,7 +29,6 @@ class QwenImageTransformerAttentionBlock(WeightModule):
         self.task = task
         self.config = config
         self.quant_method = config.get("quant_method", None)
-        self.sparge = config.get("sparge", False)
         self.lazy_load = self.config.get("lazy_load", False)
         if self.lazy_load:
@@ -139,7 +138,6 @@ class QwenImageCrossAttention(WeightModule):
         self.task = task
         self.config = config
         self.quant_method = config.get("quant_method", None)
-        self.sparge = config.get("sparge", False)
         self.attn_type = config.get("attn_type", "flash_attn3")
         self.heads = config["attention_out_dim"] // config["attention_dim_head"]
...
 import os
-import torch
 from safetensors import safe_open
 from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
@@ -56,7 +55,6 @@ class WanTransformerAttentionBlock(WeightModule):
         self.task = task
         self.config = config
         self.quant_method = config.get("quant_method", None)
-        self.sparge = config.get("sparge", False)
         self.lazy_load = self.config.get("lazy_load", False)
         if self.lazy_load:
@@ -108,7 +106,6 @@ class WanSelfAttention(WeightModule):
         self.task = task
         self.config = config
         self.quant_method = config.get("quant_method", None)
-        self.sparge = config.get("sparge", False)
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
@@ -185,16 +182,17 @@ class WanSelfAttention(WeightModule):
                     self.lazy_load_file,
                 ),
             )
-        if self.sparge:
-            assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
-            self.add_module(
-                "self_attn_1",
-                ATTN_WEIGHT_REGISTER["Sparge"](f"{block_prefix}.{self.block_index}"),
-            )
-            sparge_ckpt = torch.load(self.config["sparge_ckpt"])
-            self.self_attn_1.load(sparge_ckpt)
-        else:
-            self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
+        attention_weights_cls = ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]
+        if self.config["self_attn_1_type"] == "svg_attn":
+            attention_weights_cls.prepare(
+                head_num=self.config["num_heads"],
+                head_dim=self.config["dim"] // self.config["num_heads"],
+                sample_mse_max_row=self.config.get("svg_sample_mse_max_row", 10000),
+                num_sampled_rows=self.config.get("svg_num_sampled_rows", 64),
+                context_length=self.config.get("svg_context_length", 0),
+                sparsity=self.config.get("svg_sparsity", 0.25),
+            )
+        self.add_module("self_attn_1", attention_weights_cls())
         if self.config["seq_parallel"]:
             self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config["parallel"].get("seq_p_attn_type", "ulysses")]())
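With this wiring, SVG attention is enabled purely through the config; the optional tuning keys and their defaults, as read from the hunk above:

{
  "self_attn_1_type": "svg_attn",
  "svg_sample_mse_max_row": 10000,
  "svg_num_sampled_rows": 64,
  "svg_context_length": 0,
  "svg_sparsity": 0.25
}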
import math
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Scaled dot-product attention
Args:
Q: Query tensor [batch_size, num_heads, seq_len, d_k]
K: Key tensor [batch_size, num_heads, seq_len, d_k]
V: Value tensor [batch_size, num_heads, seq_len, d_k]
mask: Attention mask (0 indicates positions to mask, 1 indicates positions to keep)
Returns:
output: Attention output
scores: Raw (pre-softmax) attention scores
attention_weights: Attention weights after softmax
"""
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
mask_value = torch.where(mask == 0, torch.tensor(-float("inf")), torch.tensor(0.0))
scores = scores + mask_value
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
return output, scores, attention_weights
def draw_matrix(weights, save_path):
plt.imshow(weights, aspect="auto", cmap="viridis")
plt.colorbar()
plt.savefig(save_path)
plt.close()
def get_qkv_subset(x, head_index, token_start, token_end):
"""
x : [seq_len, num_heads, head_dim]
return: [batch_size, num_heads, seq_len, head_dim]
batch_size = 1, num_heads = 1, seq_len = token_end - token_start
"""
x = x[token_start:token_end, head_index, :] # [seq_len, head_dim]
x = x.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, head_dim]
return x
def draw_attention_weights(q, k, v, head_index, token_start, token_end, save_path):
"""
q k v : [seq_len, num_heads, head_dim]
"""
q_vis = get_qkv_subset(q, head_index=head_index, token_start=token_start, token_end=token_end)
k_vis = get_qkv_subset(k, head_index=head_index, token_start=token_start, token_end=token_end)
v_vis = get_qkv_subset(v, head_index=head_index, token_start=token_start, token_end=token_end)
output, scores, attention_weights = scaled_dot_product_attention(q_vis, k_vis, v_vis, mask=None)
draw_matrix(scores[0][0].float().cpu().numpy(), save_path)  # plots raw scores; use attention_weights[0][0] for post-softmax weights
print(f"Saved to {save_path}")
if __name__ == "__main__":
seq_len = 10
num_heads = 4
head_dim = 8
q = torch.randn(seq_len, num_heads, head_dim)
k = torch.randn(seq_len, num_heads, head_dim)
v = torch.randn(seq_len, num_heads, head_dim)
draw_attention_weights(q, k, v, head_index=0, token_start=0, token_end=10, save_path="scores.png")