# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
ifresidualisnotNone:
hidden_states+=residual
residual=hidden_states
out=torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
returnout,residual
else:
raiseValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
ifresidualisnotNone:
hidden_states+=residual
residual=hidden_states
out=torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
returnout,residual
else:
raiseValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
ifresidualisnotNone:
hidden_states+=residual
residual=hidden_states
out=torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
returnout,residual
else:
raiseValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
# NOTE(review): duplicated raise statements — residue of excerpting; in the
# upstream file these are the `else` branches of other classes' system
# dispatch. Kept verbatim (reformatted only); the second is unreachable here.
raise ValueError("Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
raise ValueError(
    "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
@classmethod
def static(cls, config, dim, base, device):
    """Alternate constructor (body elided in this excerpt).

    NOTE(review): presumably builds a rotary-embedding instance from
    config/dim/base/device — confirm against the full file.
    """
    ...
# Second elided region from the excerpt, preserved as-is.
...
# (stray unified-diff hunk header removed from code path: "@@ -713,9 +793,9 @@ try:" —
# the following lines originally sat inside a `try:` block around lines 713-793)
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
# Get n-d rotational scaling corrected for extrapolation: the linear ramp
# selects which half-dim frequencies are extrapolated vs. interpolated,
# scaled by `extrapolation_factor` (presumably YaRN-style — confirm).
inv_freq_mask = (
    1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
) * self.extrapolation_factor