Commit 451b975a authored by mashun1's avatar mashun1
Browse files

Update filtered_lrelu.py

parent 0bf49494
......@@ -110,8 +110,8 @@ def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np
"""
assert isinstance(x, torch.Tensor)
assert impl in ['ref', 'cuda']
if impl == 'cuda' and x.device.type == 'cuda' and _init():
return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
# if impl == 'cuda' and x.device.type == 'cuda' and _init():
# return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
#----------------------------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment