Unverified Commit 326a1b00 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Improve documentation of `ModelConfig.try_get_generation_config` to prevent...


Improve documentation of `ModelConfig.try_get_generation_config` to prevent future confusion (#21526)
Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 2d7b09b9
...@@ -1575,7 +1575,18 @@ class ModelConfig: ...@@ -1575,7 +1575,18 @@ class ModelConfig:
return self.multimodal_config return self.multimodal_config
def try_get_generation_config(self) -> dict[str, Any]: def try_get_generation_config(self) -> dict[str, Any]:
if self.generation_config in ("auto", "vllm"): """
This method attempts to retrieve the non-default values of the
generation config for this model.
The generation config can contain information about special tokens, as
well as sampling parameters. Which is why this method exists separately
to `get_diff_sampling_param`.
Returns:
A dictionary containing the non-default generation config.
"""
if self.generation_config in {"auto", "vllm"}:
config = try_get_generation_config( config = try_get_generation_config(
self.hf_config_path or self.model, self.hf_config_path or self.model,
trust_remote_code=self.trust_remote_code, trust_remote_code=self.trust_remote_code,
...@@ -1594,13 +1605,18 @@ class ModelConfig: ...@@ -1594,13 +1605,18 @@ class ModelConfig:
def get_diff_sampling_param(self) -> dict[str, Any]: def get_diff_sampling_param(self) -> dict[str, Any]:
""" """
This method returns a dictionary containing the parameters This method returns a dictionary containing the non-default sampling
that differ from the default sampling parameters. If parameters with `override_generation_config` applied.
`generation_config` is `"vllm"`, an empty dictionary is returned.
The default sampling parameters are:
- vLLM's neutral defaults if `self.generation_config="vllm"`
- the model's defaults if `self.generation_config="auto"`
- as defined in `generation_config.json` if
`self.generation_config="path/to/generation_config/dir"`
Returns: Returns:
dict[str, Any]: A dictionary with the differing sampling A dictionary containing the non-default sampling parameters.
parameters, if `generation_config` is `"vllm"` an empty dictionary.
""" """
if self.generation_config == "vllm": if self.generation_config == "vllm":
config = {} config = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment