Commit 1545db53 authored by rusty1s's avatar rusty1s
Browse files

prints

parent 7febeb31
...@@ -118,6 +118,8 @@ def test_forward(test, dtype, device): ...@@ -118,6 +118,8 @@ def test_forward(test, dtype, device):
src = tensor(test['src'], dtype, device) src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device) index = tensor(test['index'], torch.long, device)
expected = tensor(test['expected'], dtype, device) expected = tensor(test['expected'], dtype, device)
print(src)
print(index)
op = getattr(torch_scatter, 'scatter_{}'.format(test['name'])) op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
out = op(src, index, test['dim'], fill_value=test['fill_value']) out = op(src, index, test['dim'], fill_value=test['fill_value'])
......
...@@ -7,6 +7,7 @@ from torch_scatter.utils.gen import gen ...@@ -7,6 +7,7 @@ from torch_scatter.utils.gen import gen
class ScatterMul(Function): class ScatterMul(Function):
@staticmethod @staticmethod
def forward(ctx, out, src, index, dim): def forward(ctx, out, src, index, dim):
print("DRIN")
func = get_func('scatter_mul', src) func = get_func('scatter_mul', src)
func(src, index, out, dim) func(src, index, out, dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment