test_common_inference_params.py 425 Bytes
Newer Older
xingjinliang's avatar
xingjinliang committed
1
2
3
4
5
6
7
8
9
10
11
from megatron.core.inference.common_inference_params import CommonInferenceParams


class TestCommonInferenceParams:
    """Unit tests for ``CommonInferenceParams``."""

    def test_inference_params(self):
        """A key passed to ``add_attributes`` must be readable back as an attribute."""
        expected_min_tokens = 45
        params = CommonInferenceParams()
        params.add_attributes({"min_tokens": expected_min_tokens})

        # Read the attribute once so the comparison and the failure message
        # report the same value.
        actual_min_tokens = params.min_tokens
        assert (
            actual_min_tokens == expected_min_tokens
        ), f"min tokens not set correctly. it is {actual_min_tokens}"