Unverified Commit a128eec9 authored by Zihao's avatar Zihao Committed by GitHub
Browse files

register aten._convolution.default (#2137)

parent ee287620
...@@ -163,6 +163,23 @@ def meta_conv( ...@@ -163,6 +163,23 @@ def meta_conv(
return out return out
@register_meta(aten._convolution.default)
def meta_conv_1(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: List[int],
padding: List[int],
dilation: List[int],
is_transposed: bool,
output_padding: List[int],
groups: int,
*extra_args
):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@register_meta(aten.convolution_backward.default) @register_meta(aten.convolution_backward.default)
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
padding, dilation, transposed, output_padding, groups, output_mask): padding, dilation, transposed, output_padding, groups, output_mask):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment