Unverified Commit 672b6626 authored by Thomas Wang, committed by GitHub

Add FX support for torch.baddbmm and torch.Tensor.baddbmm (#18363)

parent df28de05
@@ -305,12 +305,22 @@ def torch_matmul(input, other, *, out=None):
 def torch_bmm(input, mat2, *, out=None):
     if out is not None:
-        raise ValueError("Don't support in-place abs for MetaTensor analysis")
+        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
     batch_size, n, m = input.shape
     _, _, p = mat2.shape
     return torch.empty(batch_size, n, p, device="meta")

+def torch_baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
+    if out is not None:
+        raise ValueError("Don't support in-place baddbmm for MetaTensor analysis")
+    return torch_bmm(batch1, batch2)

+def torch_tensor_baddbmm(self, batch1, batch2, *, beta=1, alpha=1, out=None):
+    return torch_baddbmm(self, batch1, batch2, beta=beta, alpha=alpha, out=out)

 def torch_einsum(equation, *operands):
     # TODO: infer shape without performing the computation, this might be quite hard.
     concrete_operands = (torch.empty_like(operand, device="cpu") for operand in operands)
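A quick sanity check of the delegation above: torch.baddbmm(input, batch1, batch2) returns a tensor with the same shape as bmm(batch1, batch2), since beta and alpha only scale values, so the new override can reuse the existing torch_bmm stub for shape inference. Below is a minimal standalone sketch; the torch_baddbmm_meta name is illustrative, the diff itself names the function torch_baddbmm.

```python
import torch

def torch_baddbmm_meta(input, batch1, batch2, *, beta=1, alpha=1, out=None):
    # Shape-only stub: baddbmm's output has the shape of bmm(batch1, batch2);
    # beta and alpha scale values, so they do not affect the shape.
    batch_size, n, _ = batch1.shape
    _, _, p = batch2.shape
    return torch.empty(batch_size, n, p, device="meta")

inp = torch.randn(2, 3, 5)   # (b, n, p)
b1 = torch.randn(2, 3, 4)    # (b, n, m)
b2 = torch.randn(2, 4, 5)    # (b, m, p)

eager = torch.baddbmm(inp, b1, b2)        # real computation on CPU
meta = torch_baddbmm_meta(inp, b1, b2)    # shape inference only, no data
assert eager.shape == meta.shape == torch.Size([2, 3, 5])
```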
@@ -495,6 +505,8 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
     torch.Tensor.mul: torch_tensor_mul,
     torch.matmul: torch_matmul,
     torch.bmm: torch_bmm,
+    torch.baddbmm: torch_baddbmm,
+    torch.Tensor.baddbmm: torch_tensor_baddbmm,
     torch.einsum: torch_einsum,
     torch.Tensor.repeat: torch_tensor_repeat,
     torch.roll: torch_roll,
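The second hunk registers both the functional form and the Tensor-method form, so traced code reaches the same shape-inference stub whether it calls torch.baddbmm(x, b1, b2) or x.baddbmm(b1, b2). The tracer-side lookup is not part of this diff, so the lookup_override helper below is a hypothetical sketch of how a dispatch table like _MANUAL_META_OVERRIDES can be consulted:

```python
from typing import Callable, Dict

import torch

# Shape-only stubs mirroring the first hunk; the method form forwards to the
# functional form with `self` as the first argument.
def torch_baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
    return torch.empty(batch1.shape[0], batch1.shape[1], batch2.shape[2], device="meta")

def torch_tensor_baddbmm(self, batch1, batch2, *, beta=1, alpha=1, out=None):
    return torch_baddbmm(self, batch1, batch2, beta=beta, alpha=alpha, out=out)

_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
    torch.baddbmm: torch_baddbmm,
    torch.Tensor.baddbmm: torch_tensor_baddbmm,
}

def lookup_override(target: Callable) -> Callable:
    # Hypothetical helper: return the registered stub if there is one,
    # otherwise fall back to the target itself.
    return _MANUAL_META_OVERRIDES.get(target, target)

x = torch.empty(2, 3, 5, device="meta")
b1 = torch.empty(2, 3, 4, device="meta")
b2 = torch.empty(2, 4, 5, device="meta")
print(lookup_override(torch.Tensor.baddbmm)(x, b1, b2).shape)  # torch.Size([2, 3, 5])
```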