        key=['x_size'] # the two above configs will be evaluated anytime
                       # the value of x_size changes
    )
    @triton.jit
    def kernel(x_ptr, x_size, **META):
        BLOCK_SIZE = META['BLOCK_SIZE']
    :note: When all the configurations are evaluated, the kernel will run multiple times.
           This means that whatever value the kernel updates will be updated multiple times.
           To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
           resets the value of the provided tensor to zero before running any configuration.
    :param configs: a list of :code:`triton.Config` objects
    :type configs: list[triton.Config]
    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
    :type key: list[str]
    :param prune_configs_by: a dict of functions used to prune configs; fields:
        'perf_model': performance model used to predict the running time of different configs; returns a running time
        'top_k': number of configs to benchmark
        'early_config_prune' (optional): a function used to prune configs early (e.g., by num_stages); it takes configs: List[Config] as input and returns the pruned configs.
    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
    :type reset_to_zero: list[str]
    :param nearest_power_of_two: whether to bucket the values of the :code:`key` arguments to the nearest power of two when caching tuning results, which reduces how often re-tuning is triggered.
    :type nearest_power_of_two: bool
    """
    def decorator(fn):
        return Autotuner(
            fn,
            fn.arg_names,
            configs,
            key,
            reset_to_zero,
            prune_configs_by,
            nearest_power_of_two,
        )

    return decorator
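

# A minimal usage sketch (the `add_kernel` and `y_ptr` names are hypothetical,
# not part of this module): passing `reset_to_zero=['y_ptr']` zeroes the output
# buffer before each candidate config is benchmarked, so an accumulating kernel
# is not timed against stale results left over from the previous config.
#
#     @autotune(
#         configs=[
#             triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
#             triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
#         ],
#         key=['x_size'],
#         reset_to_zero=['y_ptr'],
#     )
#     @triton.jit
#     def add_kernel(x_ptr, y_ptr, x_size, **META):
#         BLOCK_SIZE = META['BLOCK_SIZE']
#         ...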


def matmul248_kernel_config_pruner(configs, nargs):
    """
    The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
    """
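    # A sketch of the pruning logic, assuming `nargs` carries the runtime matmul
    # shapes 'M', 'N', 'K' (as in the matmul248 kernel) and that `math` and
    # `triton` are imported at module level: clamp each BLOCK_SIZE_* to the
    # smallest power of two covering its dimension, then deduplicate the
    # resulting configs.
    m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
    n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
    k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)

    used = set()
    for config in configs:
        block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
        block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
        block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
        group_size_m = config.kwargs["GROUP_SIZE_M"]
        candidate = (
            block_size_m,
            block_size_n,
            block_size_k,
            group_size_m,
            config.num_stages,
            config.num_warps,
        )
        if candidate in used:
            continue
        used.add(candidate)
        yield triton.Config(
            {
                "BLOCK_SIZE_M": block_size_m,
                "BLOCK_SIZE_N": block_size_n,
                "BLOCK_SIZE_K": block_size_k,
                "GROUP_SIZE_M": group_size_m,
            },
            num_stages=config.num_stages,
            num_warps=config.num_warps,
        )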


# We use the vLLM RMSNorm kernel, which can be compiled for ROCm, instead of the Flash Attention ones, which can not.
if residual is not None:
    hidden_states += residual
residual = hidden_states

out = torch.empty_like(hidden_states)
_custom_ops.rms_norm(
    out,
    hidden_states,
    self.weight.data,
    self.variance_epsilon,
)
return out, residual
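
# Reference semantics (a sketch for checking the fused kernel, not the kernel
# itself): rms_norm normalizes by the root mean square over the last dimension
# and applies a learned scale, i.e.
#
#     variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
#     out = (hidden_states * torch.rsqrt(variance + self.variance_epsilon)).to(
#         hidden_states.dtype
#     ) * self.weight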
else:
    raise ValueError(
        "Your system does not seem to be supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction. "
        "You do not seem to have awq installed; either install it (cd server && make install-awq), or try using GPTQ `--quantize gptq`: a conversion AWQ->GPTQ will happen on the fly."
    )
elifquantize=="marlin":
fromtext_generation_server.layers.marlinimport(
GPTQMarlin24Linear,
GPTQMarlin24Weight,
MarlinLinear,
MarlinWeight,
)
ifisinstance(weight,GPTQMarlin24Weight):
linear=GPTQMarlin24Linear(
weight=weight,
bias=bias,
)
elifisinstance(weight,MarlinWeight):
linear=MarlinLinear(weight=weight,bias=bias)
else:
raiseNotImplementedError(
f"The passed weight is not `marlin` compatible, loader needs to be updated."
)
else:
    raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")