Unverified commit aaf93548, authored by Ming-Xu Huang and committed by GitHub

[JAX] Allow multi-dims for dgamma and dbeta in LN descriptor. (#780)



* Allow multi-dims for dgamma and dbeta in LN descriptor.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fix the jit error in examples/jax
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
parent 2a0fe783
@@ -55,7 +55,7 @@ class Net(nn.Module):
         return x
-@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4, 5))
+@partial(jax.jit)
 def train_step(state, inputs, masks, labels, var_collect, rngs):
     """Computes gradients, loss and accuracy for a single batch."""
......
@@ -74,7 +74,7 @@ def apply_model(state, images, labels, var_collect, rngs=None):
     return grads, loss, accuracy
-@partial(jax.jit, static_argnums=(0, 1))
+@partial(jax.jit)
 def update_model(state, grads):
     """Update model params and FP8 meta."""
     state = state.apply_gradients(grads=grads[PARAMS_KEY])
......
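The two `examples/jax` hunks above correspond to the "Fix the jit error in examples/jax" item in the commit message: `train_step` and `update_model` marked array arguments (the train state, batch tensors, gradients) as static, and `jax.jit` requires static arguments to be hashable. A minimal sketch of the failure mode and the fix, using illustrative names rather than the actual example code:

```python
# Sketch only (illustrative names, not the actual examples/jax code): why
# static_argnums on array arguments breaks jax.jit, and why dropping it works.
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(0,))  # marks the array `params` as static
def bad_step(params, batch):
    return jnp.sum(params * batch)


@partial(jax.jit)  # equivalent to plain jax.jit; all array args are traced
def good_step(params, batch):
    return jnp.sum(params * batch)


params = jnp.ones((4,))
batch = jnp.arange(4.0)

try:
    bad_step(params, batch)  # fails: jax arrays are not hashable
except (TypeError, ValueError) as err:
    print("static array argument rejected:", type(err).__name__)

print(good_step(params, batch))  # 6.0
```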
@@ -385,8 +385,8 @@ class LayerNormFwdPrimitive(BasePrimitive):
             hidden_size,
             wkspace_aval.size,
             barrier_aval.size,
-            0,  # no dgamma_part in FWD pass
-            0,  # no dbeta_part in BWD pass
+            (0,),  # no dgamma_part in FWD pass
+            (0,),  # no dbeta_part in BWD pass
             jax_dtype_to_te_dtype(x_aval.dtype),
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -464,7 +464,6 @@ class LayerNormFwdPrimitive(BasePrimitive):
                 f"Enforcing no sharding of parameters hidden dim! " \
             )
         x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
         g_sharding = NamedSharding(mesh, PartitionSpec(None))
         b_sharding = NamedSharding(mesh, PartitionSpec(None))
@@ -589,8 +588,8 @@ class LayerNormBwdPrimitive(BasePrimitive):
             hidden_size,
             wkspace_aval.size,
             barrier_aval.size,
-            dgamma_part_aval.size,
-            dbeta_part_aval.size,
+            dgamma_part_aval.shape,
+            dbeta_part_aval.shape,
             jax_dtype_to_te_dtype(x_aval.dtype),
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -791,8 +790,8 @@ class RmsNormFwdPrimitive(BasePrimitive):
             hidden_size,
             wkspace_aval.size,
             barrier_aval.size,
-            0,  # no dgamma_part in FWD pass
-            0,  # no dbeta_part in BWD pass
+            (0,),  # no dgamma_part in FWD pass
+            (0,),  # no dbeta_part in BWD pass
             jax_dtype_to_te_dtype(x_aval.dtype),
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -968,8 +967,8 @@ class RmsNormBwdPrimitive(BasePrimitive):
             hidden_size,
             wkspace_aval.size,
             barrier_aval.size,
-            dgamma_part_aval.size,
-            0,  # no dbeta_part for RMSnorm
+            dgamma_part_aval.shape,
+            (0,),  # no dbeta_part for RMSnorm
             jax_dtype_to_te_dtype(x_aval.dtype),
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -3588,8 +3587,8 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
             hidden_size,
             wkspace_aval.size,
             barrier_aval.size,
-            0,  # no dgamma_part in FWD pass
-            0,  # no dbeta_part in BWD pass
+            (0,),  # no dgamma_part in FWD pass
+            (0,),  # no dbeta_part in BWD pass
             jax_dtype_to_te_dtype(x_aval.dtype),
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             jax_dtype_to_te_dtype(wkspace_aval.dtype),
@@ -3840,8 +3839,8 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
             hidden_size,
             wkspace_aval.size,
             barrier_aval.size,
-            0,  # no dgamma_part in FWD pass
-            0,  # no dbeta_part in BWD pass
+            (0,),  # no dgamma_part in FWD pass
+            (0,),  # no dbeta_part in BWD pass
             jax_dtype_to_te_dtype(x_aval.dtype),
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             jax_dtype_to_te_dtype(wkspace_aval.dtype),
......
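In the `cpp_extensions.py` hunks above, the descriptor arguments switch from flat element counts to shape tuples: the forward primitives pass a `(0,)` sentinel because they allocate no dgamma/dbeta partial-reduction buffers, while the backward primitives forward `dgamma_part_aval.shape` / `dbeta_part_aval.shape` unchanged, which is what lets those buffers have more than two dimensions. A small sketch of that convention (the helper name and shapes below are hypothetical, not the TE/JAX internals):

```python
# Hypothetical helper mirroring the new convention: shape tuples, not sizes.
import jax
import jax.numpy as jnp


def norm_descriptor_shapes(dgamma_part_aval=None, dbeta_part_aval=None):
    """Return the shape tuples the norm descriptor now carries."""
    dgamma_part_shape = dgamma_part_aval.shape if dgamma_part_aval is not None else (0,)
    dbeta_part_shape = dbeta_part_aval.shape if dbeta_part_aval is not None else (0,)
    return dgamma_part_shape, dbeta_part_shape


# Forward pass: sentinel shapes only, since no partial-reduction buffers exist.
assert norm_descriptor_shapes() == ((0,), (0,))

# Backward pass: real, possibly multi-dimensional buffer shapes are passed as-is.
dgamma_part = jax.ShapeDtypeStruct((4, 1024), jnp.float32)
dbeta_part = jax.ShapeDtypeStruct((4, 1024), jnp.float32)
print(norm_descriptor_shapes(dgamma_part, dbeta_part))  # ((4, 1024), (4, 1024))
```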
@@ -71,17 +71,28 @@ pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shap
     return PackOpaque(desc);
 }
 
-pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
-                                             size_t wkspace_size, size_t barrier_size,
-                                             size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
-                                             DType x_dtype, DType w_dtype, DType wkspace_dtype,
-                                             DType barrier_dtype, DType dgamma_part_dtype,
-                                             DType dbeta_part_dtype, bool zero_centered_gamma,
-                                             float eps, int sm_margin) {
-    return PackOpaque(CustomCallNormDescriptor{
-        batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes, dbeta_part_sizes,
-        x_dtype, w_dtype, wkspace_dtype, barrier_dtype, dgamma_part_dtype, dbeta_part_dtype,
-        zero_centered_gamma, eps, sm_margin});
+pybind11::bytes PackCustomCallNormDescriptor(
+    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
+    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
+    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
+    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) {
+    CustomCallNormDescriptor desc;
+    desc.batch_size = batch_size;
+    desc.hidden_size = hidden_size;
+    desc.wkspace_size = wkspace_size;
+    desc.barrier_size = barrier_size;
+    desc.dgamma_part_shape.from_vector(dgamma_part_shape);
+    desc.dbeta_part_shape.from_vector(dbeta_part_shape);
+    desc.x_dtype = x_dtype;
+    desc.w_dtype = w_dtype;
+    desc.wkspace_dtype = wkspace_dtype;
+    desc.barrier_dtype = barrier_dtype;
+    desc.dgamma_part_dtype = dgamma_part_dtype;
+    desc.dbeta_part_dtype = dbeta_part_dtype;
+    desc.zero_centered_gamma = zero_centered_gamma;
+    desc.eps = eps;
+    desc.sm_margin = sm_margin;
+    return PackOpaque(desc);
 }
 
 pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
@@ -529,7 +540,7 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
 }
 
 void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size,
-                           size_t barrier_size, size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
+                           size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape,
                            bool zero_centered_gamma, float eps, void *input, DType in_dtype,
                            void *weight, DType w_dtype, void *ograd, void *workspace,
                            DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
@@ -563,14 +574,14 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
     auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
     auto barrier_shape = std::vector<size_t>{barrier_size};
     auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
-    auto dgamma_part_shape = std::vector<size_t>{dgamma_part_sizes[0], dgamma_part_sizes[1]};
-    auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape, dgamma_dtype);
+    auto dgamma_part_tensor =
+        TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype);
 
     if (is_layer_norm) {
         auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
         auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
-        auto dbeta_part_shape = std::vector<size_t>{dbeta_part_sizes[0], dbeta_part_sizes[1]};
-        auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape, dbeta_dtype);
+        auto dbeta_part_tensor =
+            TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype);
 
         layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
                            rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
@@ -664,8 +675,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
     auto hidden_size = desc.hidden_size;
     auto wkspace_size = desc.wkspace_size;
     auto barrier_size = desc.barrier_size;
-    auto *dgamma_part_sizes = desc.dgamma_part_sizes;
-    auto *dbeta_part_sizes = desc.dbeta_part_sizes;
+    auto dgamma_part_shape = desc.dgamma_part_shape;
+    auto dbeta_part_shape = desc.dbeta_part_shape;
     auto in_dtype = desc.x_dtype;
     auto w_dtype = desc.w_dtype;
     auto wkspace_dtype = desc.wkspace_dtype;
@@ -689,8 +700,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
     auto *dgamma_part = buffers[10];
     auto *dbeta_part = buffers[11];
 
-    LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
-                          dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
+    LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
+                          dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
                           w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                           rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
                           dbeta_part_dtype, stream);
@@ -786,8 +797,9 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
     auto hidden_size = desc.hidden_size;
     auto wkspace_size = desc.wkspace_size;
     auto barrier_size = desc.barrier_size;
-    auto dgamma_part_sizes = desc.dgamma_part_sizes;
-    size_t dbeta_part_sizes[2] = {0, 0};
+    auto dgamma_part_shape = desc.dgamma_part_shape;
+    Shape dbeta_part_shape;
+    dbeta_part_shape.from_vector({0, 0});
     auto in_dtype = desc.x_dtype;
     auto w_dtype = desc.w_dtype;
     auto wkspace_dtype = desc.wkspace_dtype;
@@ -797,8 +809,8 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
     auto eps = desc.eps;
     auto zero_centered_gamma = desc.zero_centered_gamma;
 
-    LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
-                          dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
+    LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
+                          dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
                           w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                           rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
                           dbeta_part_dtype, stream);
......
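In `RMSNormBackward` above, `dbeta_part_shape` is filled with `{0, 0}` rather than read from the descriptor because RMSNorm has no bias term: only gamma receives a gradient. A minimal reference formulation (not the fused Transformer Engine kernel) that makes this explicit:

```python
# Minimal RMSNorm reference (not the fused Transformer Engine kernel); it has a
# gamma scale but no beta shift, so there is no dbeta or dbeta_part to reduce.
import jax
import jax.numpy as jnp


def rmsnorm(x, gamma, eps=1e-6):
    rsigma = jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
    return x * rsigma * gamma


x = jnp.ones((8, 128))
gamma = jnp.full((128,), 1.5)

# Only gamma has a gradient; a beta parameter simply does not exist here.
dgamma = jax.grad(lambda g: rmsnorm(x, g).sum())(gamma)
print(dgamma.shape)  # (128,)
```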
@@ -69,8 +69,8 @@ struct CustomCallNormDescriptor {
     size_t hidden_size;
     size_t wkspace_size;
     size_t barrier_size;
-    size_t *dgamma_part_sizes;  // 2D tensor
-    size_t *dbeta_part_sizes;   // 2D tensor
+    Shape dgamma_part_shape;
+    Shape dbeta_part_shape;
     DType x_dtype;
     DType w_dtype;
     DType wkspace_dtype;
@@ -82,13 +82,11 @@ struct CustomCallNormDescriptor {
     int sm_margin;
 };
 
-pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
-                                             size_t wkspace_size, size_t barrier_size,
-                                             size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
-                                             DType x_dtype, DType w_dtype, DType wkspace_dtype,
-                                             DType barrier_dtype, DType dgamma_part_dtype,
-                                             DType dbeta_part_dtype, bool zero_centered_gamma,
-                                             float eps, int sm_margin);
+pybind11::bytes PackCustomCallNormDescriptor(
+    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
+    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
+    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
+    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin);
 
 struct SoftmaxDescriptor {
     size_t batch_size;
......
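The header change replaces the raw `size_t *` members with `Shape` values. Since the descriptor is serialized into an opaque byte blob and copied to the custom call (`PackOpaque`), pointer members would dangle; a value type with a fixed maximum rank survives a byte-for-byte copy and can still describe an arbitrary-rank partial buffer. A rough pure-Python analogy of that layout (the field layout below is assumed for illustration, not the actual `Shape` struct):

```python
# Analogy only: pack a shape as fixed-size, by-value fields (ndim + padded dims)
# so the record survives a byte-for-byte copy across the custom-call boundary.
import struct

MAX_NDIM = 4  # assumed maximum rank for this sketch


def pack_shape(shape):
    dims = list(shape) + [0] * (MAX_NDIM - len(shape))
    return struct.pack(f"<Q{MAX_NDIM}Q", len(shape), *dims)


blob = pack_shape((4, 1024)) + pack_shape((0,))  # dgamma_part + dbeta_part sentinel
print(len(blob))  # two fixed-size records: 2 * (1 + MAX_NDIM) * 8 bytes = 80
```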