Commit 1ec09ebd authored by Tri Dao

[FusedDense] Limit matrix dims to 2M (instead of 64k)

parent 714c1b4f
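
For orientation (not part of the commit): a minimal standalone sketch of the dimension guard this change introduces. The helper name check_matrix_dims and the example shapes are invented for illustration; in the actual diff the check is inlined in the forward passes, and the FusedDenseGeluDense variant also includes weight1.shape and weight2.shape in the min().

import torch

# Sketch of the new guard (illustration only). The 65535 * 32 (~2M) bound mirrors
# the limit referenced by the diff's link to PyTorch's aten/src/ATen/native/cuda/Blas.cpp.
def check_matrix_dims(x: torch.Tensor, weight: torch.Tensor) -> None:
    batch_shape, n = x.shape[:-1], x.shape[-1]
    batch_dim = batch_shape.numel()  # product of all leading (batch) dims
    if min(batch_dim, n, *weight.shape) > 65535 * 32:
        raise RuntimeError('fused_dense only supports matrix dims <= 2M')

# Example: a (8, 512, 1024) input with a (4096, 1024) weight passes, since every
# dimension is far below 65535 * 32; previously batch_dim was capped at 64k instead.
x = torch.randn(8, 512, 1024)
w = torch.randn(4096, 1024)
check_matrix_dims(x, w)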
@@ -46,9 +46,11 @@ class FusedDenseFunc(torch.autograd.Function):
         weight = weight.contiguous()
         if process_group is not None:
             handle_x.wait()
-        batch_shape = total_x.shape[:-1]
+        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
+        if min(batch_dim, n, *weight.shape) > 65535 * 32:
+            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
         output = F.linear(total_x, weight, bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
@@ -105,11 +107,9 @@ class FusedDenseFunc(torch.autograd.Function):
 def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
                      return_residual: bool = False, process_group: Optional[ProcessGroup] = None):
-    batch_dim = x.shape[:-1].numel()
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
-    if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and batch_dim <= 64 * 1024
-            and dtype_eligible):
+    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
         return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
     else:
         assert process_group is None
@@ -222,7 +222,9 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
+        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
+            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
         if heuristic == -1:
             gelu_in = F.linear(total_x, weight1, bias1)
             output1 = F.gelu(gelu_in, approximate='tanh')
@@ -348,12 +350,10 @@ def fused_dense_gelu_dense_func(
     checkpoint_lvl: int = 0, heuristic: int = 0,
     process_group: Optional[ProcessGroup] = None
 ):
-    batch_dim = x.shape[:-1].numel()
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
-            and (bias2 is None or bias2.is_cuda) and batch_dim <= 64 * 1024
-            and dtype_eligible):
+            and (bias2 is None or bias2.is_cuda) and dtype_eligible):
         return FusedDenseGeluDenseFunc.apply(
             x, weight1, bias1, weight2, bias2,
             save_pre_act, return_residual, checkpoint_lvl, heuristic, process_group
...