Unverified Commit 071a1f51 authored by Lianmin Zheng, committed by GitHub

[Minor] clean up multimodal processor and tokenizer manager (#7624)

parent 7c0db3a6
@@ -353,8 +353,7 @@ async def generate_from_file_request(file: UploadFile, request: Request):
     obj = GenerateReqInput(
         input_embeds=input_embeds,
         sampling_params={
-            "repetition_penalty": 1.2,
-            "temperature": 0.2,
+            "temperature": 0.0,
             "max_new_tokens": 512,
         },
     )
@@ -393,16 +392,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)
-@app.api_route(
-    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
-)
-async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
-    """Endpoint for reranking documents based on query relevance."""
-    return await raw_request.app.state.openai_serving_rerank.handle_request(
-        request, raw_request
-    )
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
@@ -841,6 +830,16 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
     )
+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
...
@@ -22,17 +22,16 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
+from sglang.srt.sampling.sampling_params import SamplingParams
-# handle serialization of Image for pydantic
+# Handle serialization of Image for pydantic
 if TYPE_CHECKING:
     from PIL.Image import Image
 else:
     Image = Any
-from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling.sampling_params import SamplingParams
 @dataclass
 class SessionParams:
@@ -182,6 +181,7 @@ class GenerateReqInput:
         # Determine parallel sample count
         if self.sampling_params is None:
             self.parallel_sample_num = 1
+            return
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
...
@@ -25,7 +25,6 @@ def get_dummy_processor():
     return DummyMultimodalProcessor()
-@lru_cache()
 def import_processors():
     package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
...
@@ -180,46 +180,48 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or others
+    A single multimodal data, from a single image/video/audio or others.
+    We put the common fields first and the model-specific fields last.
     """
     modality: Modality
     hash: int = None
     pad_value: int = None
-    aspect_ratio_id: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
     image_sizes: Tuple[int, int] = None
     image_offsets: Optional[list] = None
     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
     pixel_values: Union[torch.Tensor, np.ndarray] = None
+    audio_features: Union[torch.Tensor, np.ndarray] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+    audio_offsets: Optional[List[Tuple[int, int]]] = None
+    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
+    # For qwen-vl
     image_grid_thw: Union[torch.Tensor, np.ndarray] = None
-    video_grid_thws: Union[torch.Tensor, np.ndarray] = None
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None
+    # For deepseek-vl
     image_emb_mask: Optional[torch.Tensor] = None
     image_spatial_crop: Optional[torch.Tensor] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None
+    # For minicpmv
     # [num_images, (n, w, h)]
     tgt_size: Tuple[int, int] = None
-    # kimi-vl related
-    image_grid_hws: Optional[List[torch.Tensor]] = None
-    audio_features: Union[torch.Tensor, np.ndarray] = None
-    audio_feature_lens: Optional[List[torch.Tensor]] = None
-    audio_offsets: Optional[List[Tuple[int, int]]] = None
-    # gemma3n related
+    # For mllama
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+    # For kimi-vl
+    image_grid_hws: Optional[List[torch.Tensor]] = None
+    # For gemma3n
     input_features: Optional[torch.Tensor] = None
     input_features_mask: Optional[torch.Tensor] = None
-    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
     @staticmethod
     def is_empty_list(l):
         if l is None:
@@ -339,10 +341,6 @@ class MultimodalInputs:
     image_pad_len: Optional[list] = None
     num_image_tokens: Optional[int] = None
-    # QWen2-VL related
-    mrope_positions: Optional[torch.Tensor] = None
-    mrope_position_delta: Optional[torch.Tensor] = None
     # image
     im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
@@ -358,6 +356,10 @@ class MultimodalInputs:
     audio_start_id: Optional[int] = None
     audio_end_id: Optional[int] = None
+    # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
...
@@ -150,7 +150,9 @@ class ReqState:
     # For streaming output
     last_output_offset: int = 0
     # For incremental state update.
+    # TODO(lianmin): do not initialize some lists if not needed.
     text: str = ""
     output_ids: List[int] = dataclasses.field(default_factory=list)
     input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
@@ -199,7 +201,6 @@ class TokenizerManager:
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
         self.model_config = ModelConfig.from_server_args(server_args)
-
         self.is_generation = self.model_config.is_generation
         self.is_image_gen = self.model_config.is_image_gen
         self.context_len = self.model_config.context_len
@@ -251,19 +252,36 @@ class TokenizerManager:
         self.dump_requests_threshold = 1000
         self.dump_request_list: List[Tuple] = []
         self.log_request_metadata = self.get_log_request_metadata()
+        self.asyncio_tasks = set()
+        self.session_futures = {}  # session_id -> asyncio event
+        self.max_req_input_len = None
         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
-        self.asyncio_tasks = set()
-        # For session info
-        self.session_futures = {}  # session_id -> asyncio event
-        # Set after scheduler is initialized
-        self.max_req_input_len = None
+        # For pd disaggregtion
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Start kv boostrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
+            kv_bootstrap_server_class = get_kv_class(
+                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
+                self.server_args.disaggregation_bootstrap_port
+            )
+        # For load balancing
+        self.current_load = 0
+        self.current_load_lock = asyncio.Lock()
         # Metrics
         if self.enable_metrics:
@@ -393,34 +411,14 @@ class TokenizerManager:
             ]
         )
-        # For pd disaggregtion
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-        self.current_load = 0
-        self.current_load_lock = asyncio.Lock()
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
         created_time = time.time()
         self.auto_create_handle_loop()
+        obj.normalize_batch_and_arguments()
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
@@ -428,22 +426,6 @@ class TokenizerManager:
                 "Please add `--is-embedding` when launching the server or try another model."
             )
-        obj.normalize_batch_and_arguments()
-        if isinstance(obj, GenerateReqInput):
-            return_hidden_states = obj.return_hidden_states
-            has_return_hidden_states = return_hidden_states == True or (
-                isinstance(return_hidden_states, list) and any(return_hidden_states)
-            )
-            if (
-                not self.server_args.enable_return_hidden_states
-                and has_return_hidden_states
-            ):
-                raise ValueError(
-                    "return_hidden_states=True requires the server to be started "
-                    "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
-                )
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
@@ -451,8 +433,7 @@ class TokenizerManager:
             )
         async with self.model_update_lock.reader_lock:
-            is_single = obj.is_single
-            if is_single:
+            if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
                 state = self._send_one_request(obj, tokenized_obj, created_time)
                 async for response in self._wait_one_response(obj, state, request):
@@ -514,12 +495,12 @@ class TokenizerManager:
         else:
             image_inputs: Optional[Dict] = None
-        self._validate_token_len(obj, input_ids)
+        self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
         )
-    def _validate_token_len(
+    def _validate_one_request(
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
     ) -> None:
         """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
@@ -548,24 +529,14 @@ class TokenizerManager:
             )
             raise ValueError(error_msg)
-    def _create_tokenized_object(
-        self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput],
-        input_text: str,
-        input_ids: List[int],
-        input_embeds: Optional[Union[List[float], None]] = None,
-        image_inputs: Optional[Dict] = None,
-        token_type_ids: Optional[List[int]] = None,
-    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
-        """Create a tokenized request object from common parameters."""
-        if self.is_generation:
-            return_logprob = obj.return_logprob
-            logprob_start_len = obj.logprob_start_len
-            top_logprobs_num = obj.top_logprobs_num
-            token_ids_logprob = obj.token_ids_logprob
-            session_params = (
-                SessionParams(**obj.session_params) if obj.session_params else None
+        if isinstance(obj, GenerateReqInput):
+            if (
+                obj.return_hidden_states
+                and not self.server_args.enable_return_hidden_states
+            ):
+                raise ValueError(
+                    "The server is not configured to return the hidden states. "
+                    "Please set `--enable-return-hidden-states` to enable this feature."
                 )
             if (
                 obj.custom_logit_processor
@@ -576,6 +547,16 @@ class TokenizerManager:
                     "Please set `--enable-custom-logits-processor` to enable this feature."
                 )
+    def _create_tokenized_object(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        input_text: str,
+        input_ids: List[int],
+        input_embeds: Optional[Union[List[float], None]] = None,
+        image_inputs: Optional[Dict] = None,
+        token_type_ids: Optional[List[int]] = None,
+    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
+        """Create a tokenized request object from common parameters."""
         # Parse sampling parameters
         # Note: if there are preferred sampling params, we use them if they are not
         # explicitly passed in sampling_params
@@ -589,16 +570,20 @@ class TokenizerManager:
         # Build return object
         if isinstance(obj, GenerateReqInput):
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
+            )
             tokenized_obj = TokenizedGenerateReqInput(
                 obj.rid,
                 input_text,
                 input_ids,
                 image_inputs,
                 sampling_params,
-                return_logprob,
-                logprob_start_len,
-                top_logprobs_num,
-                token_ids_logprob,
+                obj.return_logprob,
+                obj.logprob_start_len,
+                obj.top_logprobs_num,
+                obj.token_ids_logprob,
                 obj.stream,
                 bootstrap_host=obj.bootstrap_host,
                 bootstrap_port=obj.bootstrap_port,
...
@@ -98,6 +98,7 @@ class BaseMultimodalProcessor(ABC):
         self._processor = _processor
         self.arch = hf_config.architectures[0]
         self.server_args = server_args
+
         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330
...
@@ -10,7 +10,6 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import TOP_K_ALL
-from sglang.srt.utils import merge_bias_tensor
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -345,3 +344,42 @@ class SamplingBatchInfo:
         self.logit_bias = merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
         )
+
+
+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
...
@@ -4,7 +4,7 @@ from typing import List
 import torch
-from sglang.srt.utils import is_cuda, is_hip, rank0_print
+from sglang.srt.utils import is_cuda, is_hip, rank0_log
 if is_cuda() or is_hip():
     from sgl_kernel import (
@@ -344,13 +344,13 @@ def test_build_tree_kernel_efficient():
         num_verify_tokens=num_draft_token,
     )
-    rank0_print("=========== build tree kernel efficient ==========")
-    # rank0_print(f"{tree_mask=}", flush=True)
-    rank0_print(f"{position=}", flush=True)
-    rank0_print(f"{retrive_index=}", flush=True)
-    rank0_print(f"{retrive_next_token=}", flush=True)
-    rank0_print(f"{retrive_next_sibling=}", flush=True)
-    rank0_print(f"{draft_tokens=}", flush=True)
+    rank0_log("=========== build tree kernel efficient ==========")
+    # rank0_log(f"{tree_mask=}")
+    rank0_log(f"{position=}")
+    rank0_log(f"{retrive_index=}")
+    rank0_log(f"{retrive_next_token=}")
+    rank0_log(f"{retrive_next_sibling=}")
+    rank0_log(f"{draft_tokens=}")
     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
     assert retrive_index.tolist() == [
         [0, 1, 2, 3, 4, 5, 6, 7],
...
@@ -1917,14 +1917,11 @@ def configure_ipv6(dist_init_addr):
     return port, host
-def rank0_print(msg: str):
+def rank0_log(msg: str):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
     if get_tensor_model_parallel_rank() == 0:
-        print(msg, flush=True)
+        logger.info(msg)
-rank0_log = rank0_print
 def get_cuda_version():
@@ -2344,45 +2341,6 @@ def require_mlp_sync(server_args):
     return server_args.enable_dp_attention or require_gathered_buffer(server_args)
-def merge_bias_tensor(
-    lhs: Optional[torch.Tensor],
-    rhs: Optional[torch.Tensor],
-    bs1: int,
-    bs2: int,
-    device: str,
-    default: float,
-):
-    """Merge two bias tensors for batch merging.
-
-    Args:
-        lhs: Left-hand side tensor
-        rhs: Right-hand side tensor
-        bs1: Batch size of left-hand side tensor
-        bs2: Batch size of right-hand side tensor
-        device: Device to place the merged tensor on
-        default: Default value for missing tensor elements
-
-    Returns:
-        Merged tensor or None if both inputs are None
-    """
-    if lhs is None and rhs is None:
-        return None
-
-    if lhs is not None and rhs is not None:
-        return torch.cat([lhs, rhs])
-    else:
-        if lhs is not None:
-            shape, dtype = lhs.shape[1:], lhs.dtype
-        else:
-            shape, dtype = rhs.shape[1:], rhs.dtype
-
-        if lhs is None:
-            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
-        if rhs is None:
-            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
-        return torch.cat([lhs, rhs])
 def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
     import huggingface_hub as hf
...
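Note: below is a minimal usage sketch (not part of this commit) of the merge_bias_tensor helper that the change moves from sglang.srt.utils into sglang.srt.sampling.sampling_batch_info. The batch sizes, bias shape, and values are made up for illustration, and the import path assumes the post-commit module layout.

    # Merge the logit_bias of two batches when only one batch carries a bias tensor.
    import torch
    from sglang.srt.sampling.sampling_batch_info import merge_bias_tensor

    vocab_size = 8
    lhs = torch.zeros(2, vocab_size)   # batch of 2 requests with an explicit bias
    lhs[0, 3] = -5.0
    rhs = None                         # batch of 3 requests without a bias

    merged = merge_bias_tensor(lhs, rhs, bs1=2, bs2=3, device="cpu", default=0.0)
    print(merged.shape)                # torch.Size([5, 8]); the 3 missing rows are filled with 0.0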