Unverified Commit 3b8d9a8a authored by vthumbe1503, committed by GitHub
Browse files

[Pytorch] remove redundant error check in Linear module (#2420)



remove linear redundant check
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
parent 66ae3030
@@ -1536,25 +1536,11 @@ class Linear(TransformerEngineBaseModule):
def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Return the module's concatenated weight and bias tensors.

    The weight is the (no-op) concatenation of the per-split weight
    tensors returned by ``self._get_weight_tensors()``; the bias is the
    concatenation of the tensors named in ``self.bias_names`` when
    ``self.use_bias`` is set, otherwise ``None``.

    Returns:
        Tuple of (weight_tensor, bias_tensor), where bias_tensor may be
        ``None`` when the module has no bias.
    """
    # Per-split weight params; the QuantizedTensor sanity check that used
    # to live here was redundant (handled downstream) and was removed.
    unfused_weights = self._get_weight_tensors()
    # noop_cat avoids a real copy when the params are already contiguous
    # views of one buffer — TODO confirm against noop_cat's contract.
    weight_tensor = noop_cat(unfused_weights)
    if self.use_bias:
        bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
    else:
        bias_tensor = None
    return weight_tensor, bias_tensor
def onnx_forward( def onnx_forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment