Unverified commit 00da9065, authored by yhyang201, committed by GitHub

feat: Support DP Attention for step3_vl (#8699)

parent 8cd34458
@@ -11,6 +11,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.utils import is_cuda, print_info_once

 _is_cuda = is_cuda()
@@ -365,19 +366,20 @@ class VisionAttention(nn.Module):
         **kwargs,
     ):
         super().__init__()
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        self.tp_size = world_size
-        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+        self.tp_size = attn_tp_size
+        self.tp_rank = attn_tp_rank
         self.dropout = dropout
         self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads, world_size
+            num_dummy_heads + num_heads, self.tp_size
         )
         self.num_attention_kv_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads, world_size
+            num_dummy_heads + num_heads, self.tp_size
         )

         self.q_size = self.num_attention_heads_per_partition * self.head_size
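The hunk above is the core of the change: vision heads (including dummy padding heads) are now partitioned over the attention TP group rather than the full tensor-parallel world. A minimal sketch of the arithmetic, with purely illustrative sizes that do not come from this commit:

```python
def divide(numerator: int, denominator: int) -> int:
    # Mirrors dist_utils.divide: exact division with a sanity check.
    assert numerator % denominator == 0, "heads must split evenly across ranks"
    return numerator // denominator

# Illustrative sizes only: 48 vision heads, head_size 64,
# world tp_size 8, dp_size 2 -> attention TP group of 8 // 2 = 4 ranks.
num_dummy_heads, num_heads, head_size = 0, 48, 64
world_size, attn_tp_size = 8, 4

heads_per_rank_old = divide(num_dummy_heads + num_heads, world_size)    # 6 heads
heads_per_rank_new = divide(num_dummy_heads + num_heads, attn_tp_size)  # 12 heads
q_size_new = heads_per_rank_new * head_size                             # per-rank Q width: 768
```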
@@ -427,6 +429,8 @@
                 total_num_kv_heads=num_dummy_heads + num_heads,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         else:
@@ -435,6 +439,8 @@
                 output_size=3 * self.dummy_dim,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         self.proj = RowParallelLinear(
@@ -442,6 +448,8 @@
             output_size=embed_dim,
             bias=proj_bias,
             quant_config=quant_config,
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
             prefix=add_prefix("proj", prefix),
         )
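The repeated `qkv_proj`/`proj` edits pass `tp_rank`/`tp_size` through explicitly, so these linear layers shard their weights over the attention TP group instead of the default global TP group. A rough illustration of what a column-parallel shard looks like; `shard_column_weight` and all sizes here are hypothetical, not part of sglang's API:

```python
import torch

def shard_column_weight(full_weight: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # A column-parallel linear splits its weight along the output dimension;
    # each rank keeps only the tp_rank-th slice.
    out_dim = full_weight.shape[0]
    shard = out_dim // tp_size
    return full_weight[tp_rank * shard : (tp_rank + 1) * shard]

w = torch.randn(3 * 3072, 1024)  # fused-QKV-shaped weight, sizes illustrative
local_w = shard_column_weight(w, tp_rank=1, tp_size=4)  # this rank's slice
assert local_w.shape == (3 * 3072 // 4, 1024)
```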
@@ -531,11 +531,18 @@ class Step3VisionMLP(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        # Since this is a dense model, the MLP likewise adopts a DP-MLP
+        # approach modeled after DP Attention. This may not be the optimal
+        # choice and is open to further discussion.
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
         self.fc1 = ColumnParallelLinear(
             dim,
             intermediate_size,
             bias=bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("gate_proj", prefix),
         )
         self.act = ACT2FN[hidden_act]  # quick_gelu
@@ -544,6 +551,8 @@
             dim,
             bias=bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("down_proj", prefix),
         )
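The comment in the hunk above records the design decision: because the vision tower is dense, its MLP is sharded over the same attention TP group ("DP-MLP", mirroring DP Attention). A single-rank sketch of that sharding, using plain `torch.nn` stand-ins rather than sglang's parallel linear classes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DpMlpShard(nn.Module):
    # Hypothetical single-rank view: fc1 is column-sharded (each rank owns
    # intermediate_size // attn_tp_size output columns) and fc2 is row-sharded
    # (the matching input rows). In the real model an all-reduce over the
    # attention TP group follows fc2 to sum the partial outputs; F.gelu here
    # stands in for the model's quick_gelu activation.
    def __init__(self, dim: int, intermediate_size: int, attn_tp_size: int):
        super().__init__()
        shard = intermediate_size // attn_tp_size
        self.fc1 = nn.Linear(dim, shard)
        self.fc2 = nn.Linear(shard, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        partial = self.fc2(F.gelu(self.fc1(x)))
        # Distributed version: dist.all_reduce(partial, group=attn_tp_group)
        return partial

out = DpMlpShard(dim=1024, intermediate_size=4096, attn_tp_size=4)(torch.randn(2, 1024))
```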
......@@ -8,7 +8,7 @@ import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature, TensorType
from transformers import BatchFeature, ProcessorMixin, TensorType
from sglang.srt.models.step3_vl import Step3VLForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import (
@@ -276,6 +276,8 @@ class Step3VLProcessor:
         super().__init__()
         self.config = config
+        if isinstance(tokenizer, ProcessorMixin):
+            tokenizer = tokenizer.tokenizer
         self.tokenizer = tokenizer
         self.image_size = 728
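The processor change guards against being handed a full HF processor instead of a bare tokenizer: multimodal processors subclass `ProcessorMixin` and bundle a tokenizer, which the diff unwraps. A small standalone sketch of the same guard (`unwrap_tokenizer` is a hypothetical helper, not in the diff):

```python
from transformers import ProcessorMixin

def unwrap_tokenizer(tokenizer):
    # HF multimodal processors subclass ProcessorMixin and bundle a tokenizer;
    # unwrap it so downstream code always receives a plain tokenizer.
    if isinstance(tokenizer, ProcessorMixin):
        tokenizer = tokenizer.tokenizer
    return tokenizer
```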