Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a358e4df
Unverified
Commit
a358e4df
authored
Feb 01, 2026
by
Cyrus Leung
Committed by
GitHub
Feb 01, 2026
Browse files
[Refactor] Make Renderer an abstract class (#33479)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
07978117
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
49 additions
and
50 deletions
+49
-50
vllm/engine/protocol.py
vllm/engine/protocol.py
+2
-2
vllm/renderers/__init__.py
vllm/renderers/__init__.py
+2
-2
vllm/renderers/deepseek_v32.py
vllm/renderers/deepseek_v32.py
+4
-6
vllm/renderers/grok2.py
vllm/renderers/grok2.py
+4
-6
vllm/renderers/hf.py
vllm/renderers/hf.py
+4
-5
vllm/renderers/mistral.py
vllm/renderers/mistral.py
+4
-6
vllm/renderers/protocol.py
vllm/renderers/protocol.py
+16
-8
vllm/renderers/registry.py
vllm/renderers/registry.py
+3
-3
vllm/renderers/terratorch.py
vllm/renderers/terratorch.py
+4
-6
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+2
-2
vllm/v1/engine/input_processor.py
vllm/v1/engine/input_processor.py
+2
-2
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+2
-2
No files found.
vllm/engine/protocol.py
View file @
a358e4df
...
@@ -11,7 +11,7 @@ from vllm.lora.request import LoRARequest
...
@@ -11,7 +11,7 @@ from vllm.lora.request import LoRARequest
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.plugins.io_processors
import
IOProcessor
from
vllm.plugins.io_processors
import
IOProcessor
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.renderers
import
Renderer
Like
from
vllm.renderers
import
Base
Renderer
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.tasks
import
SupportedTask
from
vllm.tasks
import
SupportedTask
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
...
@@ -28,7 +28,7 @@ class EngineClient(ABC):
...
@@ -28,7 +28,7 @@ class EngineClient(ABC):
@
property
@
property
@
abstractmethod
@
abstractmethod
def
renderer
(
self
)
->
Renderer
Like
:
...
def
renderer
(
self
)
->
Base
Renderer
:
...
@
property
@
property
@
abstractmethod
@
abstractmethod
...
...
vllm/renderers/__init__.py
View file @
a358e4df
...
@@ -2,11 +2,11 @@
...
@@ -2,11 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
.params
import
ChatParams
,
TokenizeParams
,
merge_kwargs
from
.params
import
ChatParams
,
TokenizeParams
,
merge_kwargs
from
.protocol
import
Renderer
Like
from
.protocol
import
Base
Renderer
from
.registry
import
RendererRegistry
,
renderer_from_config
from
.registry
import
RendererRegistry
,
renderer_from_config
__all__
=
[
__all__
=
[
"Renderer
Like
"
,
"
Base
Renderer"
,
"RendererRegistry"
,
"RendererRegistry"
,
"renderer_from_config"
,
"renderer_from_config"
,
"ChatParams"
,
"ChatParams"
,
...
...
vllm/renderers/deepseek_v32.py
View file @
a358e4df
...
@@ -15,18 +15,18 @@ from vllm.tokenizers import cached_get_tokenizer
...
@@ -15,18 +15,18 @@ from vllm.tokenizers import cached_get_tokenizer
from
vllm.tokenizers.deepseek_v32
import
DeepseekV32Tokenizer
from
vllm.tokenizers.deepseek_v32
import
DeepseekV32Tokenizer
from
.params
import
ChatParams
from
.params
import
ChatParams
from
.protocol
import
Renderer
Like
from
.protocol
import
Base
Renderer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
DeepseekV32Renderer
(
Renderer
Like
):
class
DeepseekV32Renderer
(
Base
Renderer
):
@
classmethod
@
classmethod
def
from_config
(
def
from_config
(
cls
,
cls
,
config
:
ModelConfig
,
config
:
ModelConfig
,
tokenizer_kwargs
:
dict
[
str
,
Any
],
tokenizer_kwargs
:
dict
[
str
,
Any
],
)
->
"Renderer
Like
"
:
)
->
"
Base
Renderer"
:
return
cls
(
config
,
tokenizer_kwargs
)
return
cls
(
config
,
tokenizer_kwargs
)
def
__init__
(
def
__init__
(
...
@@ -34,9 +34,7 @@ class DeepseekV32Renderer(RendererLike):
...
@@ -34,9 +34,7 @@ class DeepseekV32Renderer(RendererLike):
config
:
ModelConfig
,
config
:
ModelConfig
,
tokenizer_kwargs
:
dict
[
str
,
Any
],
tokenizer_kwargs
:
dict
[
str
,
Any
],
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
(
config
)
self
.
config
=
config
if
config
.
skip_tokenizer_init
:
if
config
.
skip_tokenizer_init
:
tokenizer
=
None
tokenizer
=
None
...
...
vllm/renderers/grok2.py
View file @
a358e4df
...
@@ -15,18 +15,18 @@ from vllm.tokenizers import cached_get_tokenizer
...
@@ -15,18 +15,18 @@ from vllm.tokenizers import cached_get_tokenizer
from
vllm.tokenizers.grok2
import
Grok2Tokenizer
from
vllm.tokenizers.grok2
import
Grok2Tokenizer
from
.params
import
ChatParams
from
.params
import
ChatParams
from
.protocol
import
Renderer
Like
from
.protocol
import
Base
Renderer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
Grok2Renderer
(
Renderer
Like
):
class
Grok2Renderer
(
Base
Renderer
):
@
classmethod
@
classmethod
def
from_config
(
def
from_config
(
cls
,
cls
,
config
:
ModelConfig
,
config
:
ModelConfig
,
tokenizer_kwargs
:
dict
[
str
,
Any
],
tokenizer_kwargs
:
dict
[
str
,
Any
],
)
->
"Renderer
Like
"
:
)
->
"
Base
Renderer"
:
return
cls
(
config
,
tokenizer_kwargs
)
return
cls
(
config
,
tokenizer_kwargs
)
def
__init__
(
def
__init__
(
...
@@ -34,9 +34,7 @@ class Grok2Renderer(RendererLike):
...
@@ -34,9 +34,7 @@ class Grok2Renderer(RendererLike):
config
:
ModelConfig
,
config
:
ModelConfig
,
tokenizer_kwargs
:
dict
[
str
,
Any
],
tokenizer_kwargs
:
dict
[
str
,
Any
],
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
(
config
)
self
.
config
=
config
if
config
.
skip_tokenizer_init
:
if
config
.
skip_tokenizer_init
:
tokenizer
=
None
tokenizer
=
None
...
...
vllm/renderers/hf.py
View file @
a358e4df
...
@@ -34,7 +34,7 @@ from vllm.transformers_utils.processor import cached_get_processor
...
@@ -34,7 +34,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from
vllm.utils.func_utils
import
supports_kw
from
vllm.utils.func_utils
import
supports_kw
from
.params
import
ChatParams
from
.params
import
ChatParams
from
.protocol
import
Renderer
Like
from
.protocol
import
Base
Renderer
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.multimodal.inputs
import
MultiModalDataDict
,
MultiModalUUIDDict
from
vllm.multimodal.inputs
import
MultiModalDataDict
,
MultiModalUUIDDict
...
@@ -584,13 +584,13 @@ def replace_vision_chunk_video_placeholder(
...
@@ -584,13 +584,13 @@ def replace_vision_chunk_video_placeholder(
return
prompt_raw
return
prompt_raw
class
HfRenderer
(
Renderer
Like
):
class
HfRenderer
(
Base
Renderer
):
@
classmethod
@
classmethod
def
from_config
(
def
from_config
(
cls
,
cls
,
config
:
ModelConfig
,
config
:
ModelConfig
,
tokenizer_kwargs
:
dict
[
str
,
Any
],
tokenizer_kwargs
:
dict
[
str
,
Any
],
)
->
"Renderer
Like
"
:
)
->
"
Base
Renderer"
:
return
cls
(
config
,
tokenizer_kwargs
)
return
cls
(
config
,
tokenizer_kwargs
)
def
__init__
(
def
__init__
(
...
@@ -598,9 +598,8 @@ class HfRenderer(RendererLike):
...
@@ -598,9 +598,8 @@ class HfRenderer(RendererLike):
config
:
ModelConfig
,
config
:
ModelConfig
,
tokenizer_kwargs
:
dict
[
str
,
Any
],
tokenizer_kwargs
:
dict
[
str
,
Any
],
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
(
config
)
self
.
config
=
config
self
.
use_unified_vision_chunk
=
getattr
(
self
.
use_unified_vision_chunk
=
getattr
(
config
.
hf_config
,
"use_unified_vision_chunk"
,
False
config
.
hf_config
,
"use_unified_vision_chunk"
,
False
)
)
...
...
vllm/renderers/mistral.py
View file @
a358e4df
...
@@ -17,7 +17,7 @@ from vllm.tokenizers.mistral import MistralTokenizer
...
@@ -17,7 +17,7 @@ from vllm.tokenizers.mistral import MistralTokenizer
from
vllm.utils.async_utils
import
make_async
from
vllm.utils.async_utils
import
make_async
from
.params
import
ChatParams
from
.params
import
ChatParams
from
.protocol
import
Renderer
Like
from
.protocol
import
Base
Renderer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -49,13 +49,13 @@ def safe_apply_chat_template(
...
@@ -49,13 +49,13 @@ def safe_apply_chat_template(
raise
ValueError
(
str
(
e
))
from
e
raise
ValueError
(
str
(
e
))
from
e
class
MistralRenderer
(
Renderer
Like
):
class
MistralRenderer
(
Base
Renderer
):
@
classmethod
@
classmethod
def
from_config
(
def
from_config
(
cls
,
cls
,
config
:
ModelConfig
,
config
:
ModelConfig
,
tokenizer_kwargs
:
dict
[
str
,
Any
],
tokenizer_kwargs
:
dict
[
str
,
Any
],
)
->
"Renderer
Like
"
:
)
->
"
Base
Renderer"
:
return
cls
(
config
,
tokenizer_kwargs
)
return
cls
(
config
,
tokenizer_kwargs
)
def
__init__
(
def
__init__
(
...
@@ -63,9 +63,7 @@ class MistralRenderer(RendererLike):
...
@@ -63,9 +63,7 @@ class MistralRenderer(RendererLike):
config
:
ModelConfig
,
config
:
ModelConfig
,
tokenizer_kwargs
:
dict
[
str
,
Any
],
tokenizer_kwargs
:
dict
[
str
,
Any
],
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
(
config
)
self
.
config
=
config
if
config
.
skip_tokenizer_init
:
if
config
.
skip_tokenizer_init
:
tokenizer
=
None
tokenizer
=
None
...
...
vllm/renderers/protocol.py
View file @
a358e4df
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
asyncio
from
typing
import
TYPE_CHECKING
,
Any
,
Protocol
from
abc
import
ABC
,
abstractmethod
from
typing
import
TYPE_CHECKING
,
Any
from
vllm.inputs
import
EmbedsPrompt
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
EmbedsPrompt
,
TextPrompt
,
TokensPrompt
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
...
@@ -19,19 +20,26 @@ if TYPE_CHECKING:
...
@@ -19,19 +20,26 @@ if TYPE_CHECKING:
)
)
class
RendererLike
(
Protocol
):
class
BaseRenderer
(
ABC
):
config
:
"ModelConfig"
_async_tokenizer
:
AsyncMicrobatchTokenizer
@
classmethod
@
classmethod
@
abstractmethod
def
from_config
(
def
from_config
(
cls
,
cls
,
config
:
"ModelConfig"
,
config
:
"ModelConfig"
,
tokenizer_kwargs
:
dict
[
str
,
Any
],
tokenizer_kwargs
:
dict
[
str
,
Any
],
)
->
"Renderer
Like
"
:
)
->
"
Base
Renderer"
:
raise
NotImplementedError
raise
NotImplementedError
def
__init__
(
self
,
config
:
"ModelConfig"
)
->
None
:
super
().
__init__
()
self
.
config
=
config
# Lazy initialization since offline LLM doesn't use async
self
.
_async_tokenizer
:
AsyncMicrobatchTokenizer
|
None
=
None
@
property
@
property
@
abstractmethod
def
tokenizer
(
self
)
->
TokenizerLike
|
None
:
def
tokenizer
(
self
)
->
TokenizerLike
|
None
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -43,8 +51,7 @@ class RendererLike(Protocol):
...
@@ -43,8 +51,7 @@ class RendererLike(Protocol):
return
tokenizer
return
tokenizer
def
get_async_tokenizer
(
self
)
->
AsyncMicrobatchTokenizer
:
def
get_async_tokenizer
(
self
)
->
AsyncMicrobatchTokenizer
:
# Lazy initialization since offline LLM doesn't use async
if
self
.
_async_tokenizer
is
None
:
if
not
hasattr
(
self
,
"_async_tokenizer"
):
self
.
_async_tokenizer
=
AsyncMicrobatchTokenizer
(
self
.
get_tokenizer
())
self
.
_async_tokenizer
=
AsyncMicrobatchTokenizer
(
self
.
get_tokenizer
())
return
self
.
_async_tokenizer
return
self
.
_async_tokenizer
...
@@ -104,6 +111,7 @@ class RendererLike(Protocol):
...
@@ -104,6 +111,7 @@ class RendererLike(Protocol):
)
->
list
[
TextPrompt
|
TokensPrompt
|
EmbedsPrompt
]:
)
->
list
[
TextPrompt
|
TokensPrompt
|
EmbedsPrompt
]:
return
self
.
render_completions
(
prompt_input
,
prompt_embeds
)
return
self
.
render_completions
(
prompt_input
,
prompt_embeds
)
@
abstractmethod
def
render_messages
(
def
render_messages
(
self
,
self
,
messages
:
list
[
"ChatCompletionMessageParam"
],
messages
:
list
[
"ChatCompletionMessageParam"
],
...
...
vllm/renderers/registry.py
View file @
a358e4df
...
@@ -7,7 +7,7 @@ from vllm.logger import init_logger
...
@@ -7,7 +7,7 @@ from vllm.logger import init_logger
from
vllm.tokenizers.registry
import
tokenizer_args_from_config
from
vllm.tokenizers.registry
import
tokenizer_args_from_config
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
.protocol
import
Renderer
Like
from
.protocol
import
Base
Renderer
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
...
@@ -43,7 +43,7 @@ class RendererRegistry:
...
@@ -43,7 +43,7 @@ class RendererRegistry:
return
None
return
None
def
load_renderer_cls
(
self
,
renderer_mode
:
str
)
->
type
[
Renderer
Like
]:
def
load_renderer_cls
(
self
,
renderer_mode
:
str
)
->
type
[
Base
Renderer
]:
if
renderer_mode
not
in
self
.
renderers
:
if
renderer_mode
not
in
self
.
renderers
:
raise
ValueError
(
f
"No renderer registered for
{
renderer_mode
=
!
r
}
."
)
raise
ValueError
(
f
"No renderer registered for
{
renderer_mode
=
!
r
}
."
)
...
@@ -57,7 +57,7 @@ class RendererRegistry:
...
@@ -57,7 +57,7 @@ class RendererRegistry:
renderer_mode
:
str
,
renderer_mode
:
str
,
config
:
"ModelConfig"
,
config
:
"ModelConfig"
,
tokenizer_kwargs
:
dict
[
str
,
Any
],
tokenizer_kwargs
:
dict
[
str
,
Any
],
)
->
Renderer
Like
:
)
->
Base
Renderer
:
renderer_cls
=
self
.
load_renderer_cls
(
renderer_mode
)
renderer_cls
=
self
.
load_renderer_cls
(
renderer_mode
)
return
renderer_cls
.
from_config
(
config
,
tokenizer_kwargs
)
return
renderer_cls
.
from_config
(
config
,
tokenizer_kwargs
)
...
...
vllm/renderers/terratorch.py
View file @
a358e4df
...
@@ -14,24 +14,22 @@ from vllm.logger import init_logger
...
@@ -14,24 +14,22 @@ from vllm.logger import init_logger
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
.params
import
ChatParams
from
.params
import
ChatParams
from
.protocol
import
Renderer
Like
from
.protocol
import
Base
Renderer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
TerratorchRenderer
(
Renderer
Like
):
class
TerratorchRenderer
(
Base
Renderer
):
@
classmethod
@
classmethod
def
from_config
(
def
from_config
(
cls
,
cls
,
config
:
"ModelConfig"
,
config
:
"ModelConfig"
,
tokenizer_kwargs
:
dict
[
str
,
Any
],
tokenizer_kwargs
:
dict
[
str
,
Any
],
)
->
"Renderer
Like
"
:
)
->
"
Base
Renderer"
:
return
cls
(
config
)
return
cls
(
config
)
def
__init__
(
self
,
config
:
ModelConfig
)
->
None
:
def
__init__
(
self
,
config
:
ModelConfig
)
->
None
:
super
().
__init__
()
super
().
__init__
(
config
)
self
.
config
=
config
if
not
config
.
skip_tokenizer_init
:
if
not
config
.
skip_tokenizer_init
:
raise
ValueError
(
"Terratorch renderer requires `skip_tokenizer_init=True`"
)
raise
ValueError
(
"Terratorch renderer requires `skip_tokenizer_init=True`"
)
...
...
vllm/v1/engine/async_llm.py
View file @
a358e4df
...
@@ -24,7 +24,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
...
@@ -24,7 +24,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from
vllm.outputs
import
STREAM_FINISHED
,
PoolingRequestOutput
,
RequestOutput
from
vllm.outputs
import
STREAM_FINISHED
,
PoolingRequestOutput
,
RequestOutput
from
vllm.plugins.io_processors
import
get_io_processor
from
vllm.plugins.io_processors
import
get_io_processor
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.renderers
import
Renderer
Like
,
merge_kwargs
from
vllm.renderers
import
Base
Renderer
,
merge_kwargs
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.tasks
import
SupportedTask
from
vllm.tasks
import
SupportedTask
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
...
@@ -844,7 +844,7 @@ class AsyncLLM(EngineClient):
...
@@ -844,7 +844,7 @@ class AsyncLLM(EngineClient):
return
self
.
input_processor
.
get_tokenizer
()
return
self
.
input_processor
.
get_tokenizer
()
@
property
@
property
def
renderer
(
self
)
->
Renderer
Like
:
def
renderer
(
self
)
->
Base
Renderer
:
return
self
.
input_processor
.
renderer
return
self
.
input_processor
.
renderer
async
def
is_tracing_enabled
(
self
)
->
bool
:
async
def
is_tracing_enabled
(
self
)
->
bool
:
...
...
vllm/v1/engine/input_processor.py
View file @
a358e4df
...
@@ -29,7 +29,7 @@ from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
...
@@ -29,7 +29,7 @@ from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
from
vllm.multimodal.processing.context
import
set_request_id
from
vllm.multimodal.processing.context
import
set_request_id
from
vllm.multimodal.utils
import
argsort_mm_positions
from
vllm.multimodal.utils
import
argsort_mm_positions
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.renderers
import
Renderer
Like
from
vllm.renderers
import
Base
Renderer
from
vllm.sampling_params
import
_SAMPLING_EPS
,
SamplingParams
from
vllm.sampling_params
import
_SAMPLING_EPS
,
SamplingParams
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.tokenizers.mistral
import
MistralTokenizer
...
@@ -96,7 +96,7 @@ class InputProcessor:
...
@@ -96,7 +96,7 @@ class InputProcessor:
return
self
.
input_preprocessor
.
get_tokenizer
()
return
self
.
input_preprocessor
.
get_tokenizer
()
@
property
@
property
def
renderer
(
self
)
->
Renderer
Like
:
def
renderer
(
self
)
->
Base
Renderer
:
return
self
.
input_preprocessor
.
renderer
return
self
.
input_preprocessor
.
renderer
def
_validate_logprobs
(
def
_validate_logprobs
(
...
...
vllm/v1/engine/llm_engine.py
View file @
a358e4df
...
@@ -21,7 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
...
@@ -21,7 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.plugins.io_processors
import
get_io_processor
from
vllm.plugins.io_processors
import
get_io_processor
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.renderers
import
Renderer
Like
from
vllm.renderers
import
Base
Renderer
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.tasks
import
SupportedTask
from
vllm.tasks
import
SupportedTask
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
...
@@ -367,7 +367,7 @@ class LLMEngine:
...
@@ -367,7 +367,7 @@ class LLMEngine:
return
self
.
input_processor
.
get_tokenizer
()
return
self
.
input_processor
.
get_tokenizer
()
@
property
@
property
def
renderer
(
self
)
->
Renderer
Like
:
def
renderer
(
self
)
->
Base
Renderer
:
return
self
.
input_processor
.
renderer
return
self
.
input_processor
.
renderer
def
do_log_stats
(
self
)
->
None
:
def
do_log_stats
(
self
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment