"transformer_engine/pytorch/attention.py" did not exist on "0c9c0ba1fe0f5f43b4bb68a690b9d8832496216b"
Unverified Commit aaf93548 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

[JAX] Allow multi-dims for dgamma and dbeta in LN descriptor. (#780)



* Allow multi-dims for dgamma and dbeta in LN descriptor.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fix the jit error in examples/jax
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
parent 2a0fe783
......@@ -55,7 +55,7 @@ class Net(nn.Module):
return x
@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4, 5))
@partial(jax.jit)
def train_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes gradients, loss and accuracy for a single batch."""
......
......@@ -74,7 +74,7 @@ def apply_model(state, images, labels, var_collect, rngs=None):
return grads, loss, accuracy
@partial(jax.jit, static_argnums=(0, 1))
@partial(jax.jit)
def update_model(state, grads):
"""Update model params and FP8 meta."""
state = state.apply_gradients(grads=grads[PARAMS_KEY])
......
......@@ -385,8 +385,8 @@ class LayerNormFwdPrimitive(BasePrimitive):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
(0,), # no dgamma_part in FWD pass
(0,), # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
......@@ -464,7 +464,6 @@ class LayerNormFwdPrimitive(BasePrimitive):
f"Enforcing no sharding of parameters hidden dim! " \
)
x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
g_sharding = NamedSharding(mesh, PartitionSpec(None))
b_sharding = NamedSharding(mesh, PartitionSpec(None))
......@@ -589,8 +588,8 @@ class LayerNormBwdPrimitive(BasePrimitive):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
dgamma_part_aval.size,
dbeta_part_aval.size,
dgamma_part_aval.shape,
dbeta_part_aval.shape,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
......@@ -791,8 +790,8 @@ class RmsNormFwdPrimitive(BasePrimitive):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
(0,), # no dgamma_part in FWD pass
(0,), # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
......@@ -968,8 +967,8 @@ class RmsNormBwdPrimitive(BasePrimitive):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
dgamma_part_aval.size,
0, # no dbeta_part for RMSnorm
dgamma_part_aval.shape,
(0,), # no dbeta_part for RMSnorm
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
......@@ -3588,8 +3587,8 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
(0,), # no dgamma_part in FWD pass
(0,), # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
......@@ -3840,8 +3839,8 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
hidden_size,
wkspace_aval.size,
barrier_aval.size,
0, # no dgamma_part in FWD pass
0, # no dbeta_part in BWD pass
(0,), # no dgamma_part in FWD pass
(0,), # no dbeta_part in BWD pass
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
......
......@@ -71,17 +71,28 @@ pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shap
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, size_t barrier_size,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
DType x_dtype, DType w_dtype, DType wkspace_dtype,
DType barrier_dtype, DType dgamma_part_dtype,
DType dbeta_part_dtype, bool zero_centered_gamma,
float eps, int sm_margin) {
return PackOpaque(CustomCallNormDescriptor{
batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes, dbeta_part_sizes,
x_dtype, w_dtype, wkspace_dtype, barrier_dtype, dgamma_part_dtype, dbeta_part_dtype,
zero_centered_gamma, eps, sm_margin});
pybind11::bytes PackCustomCallNormDescriptor(
size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) {
CustomCallNormDescriptor desc;
desc.batch_size = batch_size;
desc.hidden_size = hidden_size;
desc.wkspace_size = wkspace_size;
desc.barrier_size = barrier_size;
desc.dgamma_part_shape.from_vector(dgamma_part_shape);
desc.dbeta_part_shape.from_vector(dbeta_part_shape);
desc.x_dtype = x_dtype;
desc.w_dtype = w_dtype;
desc.wkspace_dtype = wkspace_dtype;
desc.barrier_dtype = barrier_dtype;
desc.dgamma_part_dtype = dgamma_part_dtype;
desc.dbeta_part_dtype = dbeta_part_dtype;
desc.zero_centered_gamma = zero_centered_gamma;
desc.eps = eps;
desc.sm_margin = sm_margin;
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
......@@ -529,7 +540,7 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
}
void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size,
size_t barrier_size, size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape,
bool zero_centered_gamma, float eps, void *input, DType in_dtype,
void *weight, DType w_dtype, void *ograd, void *workspace,
DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
......@@ -563,14 +574,14 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype);
auto barrier_shape = std::vector<size_t>{barrier_size};
auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype);
auto dgamma_part_shape = std::vector<size_t>{dgamma_part_sizes[0], dgamma_part_sizes[1]};
auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape, dgamma_dtype);
auto dgamma_part_tensor =
TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype);
if (is_layer_norm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
auto dbeta_part_shape = std::vector<size_t>{dbeta_part_sizes[0], dbeta_part_sizes[1]};
auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape, dbeta_dtype);
auto dbeta_part_tensor =
TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype);
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
......@@ -664,8 +675,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto *dgamma_part_sizes = desc.dgamma_part_sizes;
auto *dbeta_part_sizes = desc.dbeta_part_sizes;
auto dgamma_part_shape = desc.dgamma_part_shape;
auto dbeta_part_shape = desc.dbeta_part_shape;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
......@@ -689,8 +700,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto *dgamma_part = buffers[10];
auto *dbeta_part = buffers[11];
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, stream);
......@@ -786,8 +797,9 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
auto hidden_size = desc.hidden_size;
auto wkspace_size = desc.wkspace_size;
auto barrier_size = desc.barrier_size;
auto dgamma_part_sizes = desc.dgamma_part_sizes;
size_t dbeta_part_sizes[2] = {0, 0};
auto dgamma_part_shape = desc.dgamma_part_shape;
Shape dbeta_part_shape;
dbeta_part_shape.from_vector({0, 0});
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto wkspace_dtype = desc.wkspace_dtype;
......@@ -797,8 +809,8 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_sizes,
dbeta_part_sizes, zero_centered_gamma, eps, input, in_dtype, weight,
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, stream);
......
......@@ -69,8 +69,8 @@ struct CustomCallNormDescriptor {
size_t hidden_size;
size_t wkspace_size;
size_t barrier_size;
size_t *dgamma_part_sizes; // 2D tensor
size_t *dbeta_part_sizes; // 2D tensor
Shape dgamma_part_shape;
Shape dbeta_part_shape;
DType x_dtype;
DType w_dtype;
DType wkspace_dtype;
......@@ -82,13 +82,11 @@ struct CustomCallNormDescriptor {
int sm_margin;
};
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
size_t wkspace_size, size_t barrier_size,
size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
DType x_dtype, DType w_dtype, DType wkspace_dtype,
DType barrier_dtype, DType dgamma_part_dtype,
DType dbeta_part_dtype, bool zero_centered_gamma,
float eps, int sm_margin);
pybind11::bytes PackCustomCallNormDescriptor(
size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin);
struct SoftmaxDescriptor {
size_t batch_size;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment