Unverified commit 69b3bb9a, authored by Liangsheng Yin and committed by GitHub

Unify forward mode (#1360)

parent 689ff588
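This commit makes the forward mode part of the batch instead of a separate argument: ScheduleBatch gains a forward_mode field, prepare_for_extend() / prepare_for_decode() set it, and ModelRunner.forward() dispatches on batch.forward_mode. A minimal sketch of the calling convention before and after the change (setup and variable names are illustrative, taken from the hunks below):

    # Before: the caller passed the mode next to the batch and had to keep
    # the two consistent by hand.
    batch.prepare_for_extend(model_runner.model_config.vocab_size)
    sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)

    batch.prepare_for_decode()
    sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)

    # After: the prepare_* call records the mode on the batch, so forward()
    # takes only the batch and reads batch.forward_mode internally.
    batch.prepare_for_extend(model_runner.model_config.vocab_size)
    sample_output, logits_output = model_runner.forward(batch)

    batch.prepare_for_decode()
    sample_output, logits_output = model_runner.forward(batch)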
@@ -60,7 +60,6 @@ import torch.distributed as dist
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -208,14 +207,14 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
+    sample_output, logits_output = model_runner.forward(batch)
     next_token_ids = sample_output.batch_next_token_ids.tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
+    sample_output, logits_output = model_runner.forward(batch)
     next_token_ids = sample_output.batch_next_token_ids.tolist()
     return next_token_ids, logits_output.next_token_logits
...
@@ -103,7 +103,7 @@ class LogitsProcessor(nn.Module):
     @staticmethod
     def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
-        if logits_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode.is_decode():
             output_top_logprobs = []
             max_k = max(logits_metadata.top_logprobs_nums)
             ret = all_logprobs.topk(max_k, dim=1)
@@ -163,7 +163,7 @@ class LogitsProcessor(nn.Module):
         assert isinstance(logits_metadata, LogitsMetadata)

         # Get the last hidden states and last logits for the next token prediction
-        if logits_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode.is_decode():
             last_index = None
             last_hidden = hidden_states
         else:
@@ -195,7 +195,7 @@ class LogitsProcessor(nn.Module):
             )
         else:
             # When logprob is requested, compute the logits for all tokens.
-            if logits_metadata.forward_mode == ForwardMode.DECODE:
+            if logits_metadata.forward_mode.is_decode():
                 last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

                 # Get the logprob of top-k tokens
...
@@ -197,9 +197,9 @@ class RadixAttention(nn.Module):
         k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
         v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

-        if input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode.is_extend():
             return self.extend_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
+        elif input_metadata.forward_mode.is_decode():
             return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
...
@@ -29,6 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

 if TYPE_CHECKING:
@@ -334,6 +335,8 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool
     tree_cache: BasePrefixCache

+    forward_mode: ForwardMode = None
+
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
@@ -397,6 +400,8 @@ class ScheduleBatch:
         return out_cache_loc

     def prepare_for_extend(self, vocab_size: int):
+        self.forward_mode = ForwardMode.EXTEND
+
         bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -626,6 +631,8 @@ class ScheduleBatch:
         return jump_forward_reqs

     def prepare_for_decode(self, input_ids=None):
+        self.forward_mode = ForwardMode.DECODE
+
         if input_ids is None:
             input_ids = [
                 r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
...
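Condensed, the ScheduleBatch side of the change amounts to the shape below (a sketch only; unrelated fields and the real bodies of the prepare methods are elided):

    @dataclass
    class ScheduleBatch:
        # ... existing request/cache fields ...
        forward_mode: ForwardMode = None  # set by the prepare_* methods

        def prepare_for_extend(self, vocab_size: int):
            self.forward_mode = ForwardMode.EXTEND
            # ... build input_ids, positions, out_cache_loc, ...

        def prepare_for_decode(self, input_ids=None):
            self.forward_mode = ForwardMode.DECODE
            # ... append the last sampled token for each request ...

Downstream, InputMetadata.from_schedule_batch reads batch.forward_mode, so the mode no longer travels as a separate parameter.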
@@ -53,7 +53,6 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -521,9 +520,7 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                sample_output, logits_output = self.model_runner.forward(
-                    batch, ForwardMode.EXTEND
-                )
+                sample_output, logits_output = self.model_runner.forward(batch)
                 next_token_ids = batch.check_sample_results(sample_output)
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
@@ -588,7 +585,7 @@ class ModelTpServer:
                     pt += req.extend_input_len
             else:
                 assert batch.extend_num_tokens != 0
-                logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+                logits_output = self.model_runner.forward(batch)
                 embeddings = logits_output.embeddings.tolist()

             # Check finish conditions
@@ -699,9 +696,7 @@ class ModelTpServer:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        sample_output, logits_output = self.model_runner.forward(
-            batch, ForwardMode.DECODE
-        )
+        sample_output, logits_output = self.model_runner.forward(batch)
         next_token_ids = batch.check_sample_results(sample_output)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
...
@@ -25,10 +25,9 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-
 if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -41,6 +40,15 @@ class ForwardMode(IntEnum):
     # Decode one token.
     DECODE = auto()

+    def is_prefill(self):
+        return self == ForwardMode.PREFILL
+
+    def is_extend(self):
+        return self == ForwardMode.EXTEND
+
+    def is_decode(self):
+        return self == ForwardMode.DECODE
+

 @dataclass
 class InputMetadata:
@@ -102,7 +110,7 @@ class InputMetadata:
     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets

-        if self.forward_mode == ForwardMode.DECODE:
+        if self.forward_mode.is_decode():
             if True:
                 self.positions = self.seq_lens - 1
             else:
@@ -141,7 +149,7 @@ class InputMetadata:
         self.positions = self.positions.to(torch.int64)

     def compute_extend_infos(self, batch: ScheduleBatch):
-        if self.forward_mode == ForwardMode.DECODE:
+        if self.forward_mode.is_decode():
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
             self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
         else:
@@ -173,10 +181,9 @@ class InputMetadata:
         cls,
         model_runner: "ModelRunner",
         batch: ScheduleBatch,
-        forward_mode: ForwardMode,
     ):
         ret = cls(
-            forward_mode=forward_mode,
+            forward_mode=batch.forward_mode,
             sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
@@ -194,13 +201,11 @@
         ret.compute_extend_infos(batch)

-        if (
-            forward_mode != ForwardMode.DECODE
-            or model_runner.server_args.disable_flashinfer
-        ):
+        fm = batch.forward_mode
+        if not fm.is_decode() or model_runner.server_args.disable_flashinfer:
             ret.total_num_tokens = int(torch.sum(ret.seq_lens))

-        if forward_mode != ForwardMode.DECODE:
+        if not fm.is_decode():
             ret.init_multimuldal_info(batch)

         if model_runner.server_args.disable_flashinfer:
@@ -209,7 +214,7 @@ class InputMetadata:
         flashinfer_use_ragged = False
         if not model_runner.server_args.disable_flashinfer:
             if (
-                forward_mode != ForwardMode.DECODE
+                not fm.is_decode()
                 and int(torch.sum(ret.seq_lens)) > 4096
                 and model_runner.sliding_window_size is None
             ):
@@ -226,7 +231,7 @@ class InputMetadata:
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)

-        if self.forward_mode == ForwardMode.DECODE:
+        if self.forward_mode.is_decode():
             self.triton_max_extend_len = None
         else:
             self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
@@ -239,7 +244,7 @@
         prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
-        if self.forward_mode == ForwardMode.DECODE:
+        if self.forward_mode.is_decode():
             prefix_lens = None
         else:
             prefix_lens = self.extend_prefix_lens
@@ -339,7 +344,7 @@ def update_flashinfer_indices(
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

-        if forward_mode == ForwardMode.DECODE:
+        if forward_mode.is_decode():
             # CUDA graph uses different flashinfer_decode_wrapper
             if flashinfer_decode_wrapper is None:
                 flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
@@ -388,7 +393,7 @@ def update_flashinfer_indices(
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
         for wrapper_id in range(2):
             if wrapper_id == 0:
-                if forward_mode == ForwardMode.DECODE:
+                if forward_mode.is_decode():
                     paged_kernel_lens = torch.minimum(
                         seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
                     )
@@ -418,7 +423,7 @@
                 kv_indices,
             )

-            if forward_mode == ForwardMode.DECODE:
+            if forward_mode.is_decode():
                 # CUDA graph uses different flashinfer_decode_wrapper
                 if flashinfer_decode_wrapper is None:
                     flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
...
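The new is_prefill()/is_extend()/is_decode() helpers are plain predicates on the enum, so call sites become readable mode checks instead of enum comparisons. For example (illustrative asserts, using the ForwardMode declared above):

    mode = ForwardMode.DECODE
    assert mode.is_decode()
    assert not mode.is_extend()
    # Same truth value as the old spelling, just shorter at each call site:
    assert mode.is_decode() == (mode == ForwardMode.DECODE)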
@@ -530,11 +530,7 @@ class ModelRunner:
         ):
             return self.cuda_graph_runner.replay(batch)

-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            ForwardMode.DECODE,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)

         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -542,11 +538,7 @@
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
         if self.is_generation:
             return self.model.forward(
                 batch.input_ids, input_metadata.positions, input_metadata
@@ -562,11 +554,7 @@
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
@@ -577,16 +565,18 @@
         )

     def forward(
-        self, batch: ScheduleBatch, forward_mode: ForwardMode
+        self, batch: ScheduleBatch
     ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
-        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
+        assert batch.forward_mode is not None
+
+        if self.is_multimodal_model and batch.forward_mode.is_extend():
             return self.forward_extend_multi_modal(batch)
-        elif forward_mode == ForwardMode.DECODE:
+        elif batch.forward_mode.is_decode():
             return self.forward_decode(batch)
-        elif forward_mode == ForwardMode.EXTEND:
+        elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)
         else:
-            raise ValueError(f"Invaid forward mode: {forward_mode}")
+            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")


 @lru_cache()
...
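With the mode stored on the batch, ModelRunner.forward() reduces to the dispatch below; the assert makes a batch that skipped prepare_for_extend()/prepare_for_decode() fail fast instead of running in the wrong mode (a condensed restatement of the hunk above, not new behavior):

    def forward(self, batch: ScheduleBatch):
        assert batch.forward_mode is not None  # set by the prepare_* methods

        if self.is_multimodal_model and batch.forward_mode.is_extend():
            return self.forward_extend_multi_modal(batch)
        elif batch.forward_mode.is_decode():
            return self.forward_decode(batch)
        elif batch.forward_mode.is_extend():
            return self.forward_extend(batch)
        else:
            raise ValueError(f"Invalid forward mode: {batch.forward_mode}")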
@@ -136,7 +136,7 @@ class LlavaBaseForCausalLM(nn.Module):
         image_sizes: Optional[List[List[int]]] = None,
         image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
@@ -357,7 +357,7 @@ class LlavaBaseForCausalLM(nn.Module):
             return self.language_model(
                 input_ids, positions, input_metadata, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
+        elif input_metadata.forward_mode.is_decode():
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -116,7 +116,7 @@ class LlavaVidForCausalLM(nn.Module):
         image_sizes: Optional[List[List[int]]] = None,
         image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size

             # Embed text inputs
@@ -199,7 +199,7 @@ class LlavaVidForCausalLM(nn.Module):
             return self.language_model(
                 input_ids, positions, input_metadata, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
+        elif input_metadata.forward_mode.is_decode():
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...