sampling_params.py 254 Bytes
Newer Older
chenzk's avatar
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
from dataclasses import dataclass


@dataclass
class SamplingParams:
    """Container for sampling settings, validated on construction.

    Attributes:
        temperature: Defaults to 1.0. Must be a non-negative, non-NaN number.
            (Presumably the softmax/sampling temperature -- confirm against
            the code that consumes this object.)
        max_new_tokens: Defaults to 256. Not validated here.
            (Presumably the cap on generated tokens -- confirm with callers.)

    Raises:
        ValueError: If ``temperature`` is negative or NaN.
    """

    temperature: float = 1.0
    max_new_tokens: int = 256

    def __post_init__(self):
        """Reject temperature values that are negative or NaN."""
        temperature = self.temperature
        # NaN is the only value that compares unequal to itself. It must be
        # checked explicitly because ``nan < 0`` is False and would otherwise
        # slip through the negativity check below.
        if temperature != temperature:
            raise ValueError("Temperature cannot be NaN")
        if temperature < 0:
            raise ValueError("Temperature cannot be negative")