Unverified Commit 9e26e354 authored by tigertang's avatar tigertang Committed by GitHub
Browse files

Add aten::detach in infer_shape to support .detach (#3244)

parent 163e064e
...@@ -273,7 +273,8 @@ infer_from_inshape = { ...@@ -273,7 +273,8 @@ infer_from_inshape = {
'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape), 'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape),
'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask), 'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask),
'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask), 'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask),
'aten::dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask) 'aten::dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask),
'aten::detach': lambda module_masks, mask: dropout_inshape(module_masks, mask)
} }
""" """
...@@ -308,7 +309,8 @@ infer_from_outshape = { ...@@ -308,7 +309,8 @@ infer_from_outshape = {
'aten::mean': lambda module_masks, mask, shape: mean_outshape(module_masks, mask, shape), 'aten::mean': lambda module_masks, mask, shape: mean_outshape(module_masks, mask, shape),
'Dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask), 'Dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask),
'Dropout2d': lambda module_masks, mask: dropout_outshape(module_masks, mask), 'Dropout2d': lambda module_masks, mask: dropout_outshape(module_masks, mask),
'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask) 'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask),
'aten::detach': lambda module_masks, mask: dropout_outshape(module_masks, mask)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment