Unverified commit 00da9065 authored by yhyang201, committed by GitHub

feat: Support DP Attention for step3_vl (#8699)

parent 8cd34458
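The change routes the step3_vl vision attention and MLP shards through the attention tensor-parallel (TP) group exposed by sglang.srt.layers.dp_attention instead of the full TP world size, so the vision encoder keeps working when DP Attention is enabled. The sketch below is a rough illustration of the resulting head-sharding arithmetic, not SGLang code; the relation attn_tp_size = tp_world_size // dp_size and all concrete sizes are assumptions made for the example.

# Illustrative sketch only (not SGLang code). Assumes DP attention splits the full
# tensor-parallel world into dp_size groups, so each attention-TP group has
# attn_tp_size = tp_world_size // dp_size ranks; all sizes below are hypothetical.

def heads_per_partition(total_heads: int, shard_size: int) -> int:
    # Mirrors the role of dist_utils.divide: heads must split evenly across the group.
    assert total_heads % shard_size == 0, "heads must divide evenly"
    return total_heads // shard_size

tp_world_size = 8                          # hypothetical: 8 GPUs in the TP world
dp_size = 4                                # hypothetical: 4 DP-attention replicas
attn_tp_size = tp_world_size // dp_size    # assumed grouping: 2 ranks per attention-TP group

num_heads = 16                             # hypothetical vision head count
print(heads_per_partition(num_heads, tp_world_size))  # before the change: 2 heads per rank
print(heads_per_partition(num_heads, attn_tp_size))   # after the change: 8 heads per rank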
@@ -11,6 +11,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.utils import is_cuda, print_info_once

 _is_cuda = is_cuda()
@@ -365,19 +366,20 @@ class VisionAttention(nn.Module):
         **kwargs,
     ):
         super().__init__()
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        self.tp_size = world_size
-        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+        self.tp_size = attn_tp_size
+        self.tp_rank = attn_tp_rank
         self.dropout = dropout
         self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads, world_size
+            num_dummy_heads + num_heads, self.tp_size
         )
         self.num_attention_kv_heads_per_partition = dist_utils.divide(
-            num_dummy_heads + num_heads, world_size
+            num_dummy_heads + num_heads, self.tp_size
         )
         self.q_size = self.num_attention_heads_per_partition * self.head_size
@@ -427,6 +429,8 @@ class VisionAttention(nn.Module):
                 total_num_kv_heads=num_dummy_heads + num_heads,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         else:
@@ -435,6 +439,8 @@
                 output_size=3 * self.dummy_dim,
                 bias=qkv_bias,
                 quant_config=quant_config,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
                 prefix=add_prefix("qkv_proj", prefix),
             )
         self.proj = RowParallelLinear(
@@ -442,6 +448,8 @@
             output_size=embed_dim,
             bias=proj_bias,
             quant_config=quant_config,
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
             prefix=add_prefix("proj", prefix),
         )
...
@@ -531,11 +531,18 @@ class Step3VisionMLP(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        # Since this is a dense model,
+        # the MLP component likewise adopts a DP-MLP approach modeled after DP Attention.
+        # This choice may not represent the optimal solution and remains open to further deliberation.
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
         self.fc1 = ColumnParallelLinear(
             dim,
             intermediate_size,
             bias=bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("gate_proj", prefix),
         )
         self.act = ACT2FN[hidden_act]  # quick_gelu
@@ -544,6 +551,8 @@
             dim,
             bias=bias,
             quant_config=quant_config,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
             prefix=add_prefix("down_proj", prefix),
         )
...
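For context on the DP-MLP comment in the hunk above: fc1 (ColumnParallelLinear) shards the intermediate dimension and fc2 (RowParallelLinear) consumes that same shard, so both layers must be partitioned with the same tp_rank/tp_size for the per-rank shapes to line up. Below is a minimal plain-PyTorch sketch of that pairing, with hypothetical sizes and GELU standing in for the model's quick_gelu; it is not SGLang code.

# Plain-PyTorch sketch of the column-parallel -> row-parallel pairing; not SGLang code.
# All sizes are hypothetical; the real layers additionally all-reduce the per-rank
# partial outputs over the attention-TP group.
import torch
import torch.nn.functional as F

dim, intermediate_size, tp_size = 64, 512, 4
shard = intermediate_size // tp_size       # each rank holds 128 of the 512 intermediate units

w_fc1 = torch.randn(shard, dim)            # fc1 shard: rows of the full (512, 64) weight
w_fc2 = torch.randn(dim, shard)            # fc2 shard: columns of the full (64, 512) weight

x = torch.randn(2, dim)
hidden = F.gelu(x @ w_fc1.t())             # per-rank activation, shape (2, 128)
partial_out = hidden @ w_fc2.t()           # per-rank partial output, shape (2, 64)
print(partial_out.shape)                   # torch.Size([2, 64]); an all-reduce would complete it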
@@ -8,7 +8,7 @@ import torch
 from PIL import Image
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
-from transformers import BatchFeature, TensorType
+from transformers import BatchFeature, ProcessorMixin, TensorType

 from sglang.srt.models.step3_vl import Step3VLForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import (
@@ -276,6 +276,8 @@ class Step3VLProcessor:
         super().__init__()
         self.config = config
+        if isinstance(tokenizer, ProcessorMixin):
+            tokenizer = tokenizer.tokenizer
         self.tokenizer = tokenizer
         self.image_size = 728
...
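The processor change makes Step3VLProcessor tolerant of being handed a full Hugging Face processor instead of a bare tokenizer: if the object is a ProcessorMixin, its inner tokenizer is unwrapped first. A small sketch of the same guard in isolation; the helper name is illustrative, not SGLang API.

# Minimal sketch of the unwrap guard shown in the diff; helper name is hypothetical.
from transformers import ProcessorMixin

def unwrap_tokenizer(tokenizer_or_processor):
    # A ProcessorMixin bundles sub-processors (tokenizer, image processor, ...);
    # the text tokenizer lives on its .tokenizer attribute, matching the diff above.
    if isinstance(tokenizer_or_processor, ProcessorMixin):
        return tokenizer_or_processor.tokenizer
    return tokenizer_or_processor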