"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "caf9e985df761413f8bbeea67eb406b86daa71a8"
Unverified commit 554fbf93, authored by yizhang2077 and committed by GitHub
Browse files

[Bugfix] qwen2vl forward_extend (#1727)

parent b48edff6
...@@ -35,7 +35,6 @@ from dataclasses import dataclass ...@@ -35,7 +35,6 @@ from dataclasses import dataclass
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional
import numpy as np
import torch import torch
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
...@@ -134,16 +133,23 @@ class ForwardBatch: ...@@ -134,16 +133,23 @@ class ForwardBatch:
) )
elif self.forward_mode.is_extend(): elif self.forward_mode.is_extend():
for i, image_inputs in enumerate(batch.image_inputs): for i, image_inputs in enumerate(batch.image_inputs):
extend_start_loc, extend_seq_len, extend_prefix_len = (
self.extend_start_loc[i],
self.extend_seq_lens[i],
self.extend_prefix_lens[i],
)
if image_inputs is None: if image_inputs is None:
# text only # text only
mrope_positions = [[i for i in range(self.seq_lens[i])]] * 3 mrope_positions = [
[
pos
for pos in range(
extend_prefix_len, extend_prefix_len + extend_seq_len
)
]
] * 3
mrope_position_delta = 0 mrope_position_delta = 0
else: else:
extend_start_loc, extend_seq_len, extend_prefix_len = (
self.extend_start_loc[i],
self.extend_seq_lens[i],
self.extend_prefix_lens[i],
)
mrope_positions, mrope_position_delta = ( mrope_positions, mrope_position_delta = (
MRotaryEmbedding.get_input_positions( MRotaryEmbedding.get_input_positions(
input_tokens=self.input_ids[ input_tokens=self.input_ids[
...@@ -163,12 +169,9 @@ class ForwardBatch: ...@@ -163,12 +169,9 @@ class ForwardBatch:
mrope_positions_list[i] = mrope_positions mrope_positions_list[i] = mrope_positions
batch.mrope_positions_delta[i].append(mrope_position_delta) batch.mrope_positions_delta[i].append(mrope_position_delta)
self.mrope_positions = torch.tensor( self.mrope_positions = torch.concat(
np.concatenate( [torch.tensor(pos, device=device) for pos in mrope_positions_list],
[np.array(pos) for pos in mrope_positions_list], axis=1,
axis=1,
),
device=device,
) )
self.mrope_positions = self.mrope_positions.to(torch.int64) self.mrope_positions = self.mrope_positions.to(torch.int64)
...@@ -177,18 +180,15 @@ class ForwardBatch: ...@@ -177,18 +180,15 @@ class ForwardBatch:
if self.forward_mode.is_decode(): if self.forward_mode.is_decode():
self.positions = (self.seq_lens - 1).to(torch.int64) self.positions = (self.seq_lens - 1).to(torch.int64)
else: else:
self.positions = torch.tensor( self.positions = torch.concat(
np.concatenate( [
[ torch.arange(prefix_len, prefix_len + extend_len, device=device)
np.arange(prefix_len, prefix_len + extend_len) for prefix_len, extend_len in zip(
for prefix_len, extend_len in zip( batch.extend_prefix_lens, batch.extend_seq_lens
batch.extend_prefix_lens, batch.extend_seq_lens )
) ],
], axis=0,
axis=0, )
),
device=device,
).to(torch.int64)
@classmethod @classmethod
def init_new( def init_new(
...@@ -213,15 +213,6 @@ class ForwardBatch: ...@@ -213,15 +213,6 @@ class ForwardBatch:
# Init position information # Init position information
if not ret.forward_mode.is_decode(): if not ret.forward_mode.is_decode():
ret.positions = torch.concat(
[
torch.arange(prefix_len, prefix_len + extend_len, device=device)
for prefix_len, extend_len in zip(
batch.extend_prefix_lens, batch.extend_seq_lens
)
],
axis=0,
)
ret.image_inputs = batch.image_inputs ret.image_inputs = batch.image_inputs
ret.extend_seq_lens = torch.tensor( ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32 batch.extend_seq_lens, dtype=torch.int32
......
...@@ -362,10 +362,6 @@ class TestQWen2VLServer(TestOpenAIVisionServer): ...@@ -362,10 +362,6 @@ class TestQWen2VLServer(TestOpenAIVisionServer):
) )
cls.base_url += "/v1" cls.base_url += "/v1"
def test_mixed_batch(self):
# FIXME: Temporarily skip this test.
pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment