Unverified Commit af2a0c16 authored by Hua Huang's avatar Hua Huang Committed by GitHub
Browse files

[JAX] Async issuing D2H memcpy for grouped_gemm group_sizes array (#2213)



* Try async copy of grouped GEMM group_sizes data
Signed-off-by: Hua Huang <huah@nvidia.com>

---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 66f9b3cb
...@@ -1366,14 +1366,22 @@ class TestGroupedDense: ...@@ -1366,14 +1366,22 @@ class TestGroupedDense:
lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
dtype, input_shape, layout dtype, input_shape, layout
) )
num_gemms = input_shape[0]
_ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))(
group_sizes,
num_gemms=num_gemms,
)
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
# jitting grouped_gemm # jitting grouped_gemm
prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( prim_out = jax.jit(
tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes")
)(
lhs, lhs,
rhs, rhs,
group_sizes, group_sizes,
contracting_dims, contracting_dims,
use_async_d2h_group_sizes=True,
) )
self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
......
...@@ -58,6 +58,7 @@ __all__ = [ ...@@ -58,6 +58,7 @@ __all__ = [
"collective_gemm_bootstrap", "collective_gemm_bootstrap",
"noop_collective_op_set", "noop_collective_op_set",
"gemm", "gemm",
"grouped_gemm_copy_group_sizes",
"grouped_gemm", "grouped_gemm",
"gemm_uses_jax_dot", "gemm_uses_jax_dot",
"sanitize_dims", "sanitize_dims",
...@@ -1237,6 +1238,63 @@ def _te_gemm( ...@@ -1237,6 +1238,63 @@ def _te_gemm(
) )
class GroupedGemmCopySizesPrimitive(BasePrimitive):
    """
    Primitive for async copying group sizes from device to host.

    Issues an asynchronous D2H memcpy of the grouped-GEMM group_sizes array so
    that a later grouped_gemm call (with use_async_d2h_group_sizes=True) can
    read the sizes on the host without a blocking stream synchronize.
    """

    name = "te_grouped_gemm_d2h_group_sizes_ffi"
    multiple_results = False
    # num_gemms (positional index 1) is a static, compile-time argument
    impl_static_args = (1,)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        group_sizes_aval,
        *,
        num_gemms,
    ):
        # The output aliases the input buffer (see lowering), so the abstract
        # value is unchanged: same shape and dtype as group_sizes.
        del num_gemms
        out_aval = group_sizes_aval
        return out_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        out = GroupedGemmCopySizesPrimitive.abstract(*args, **kwargs)
        return out

    @staticmethod
    def lowering(
        ctx,
        group_sizes,
        num_gemms,
    ):
        return jax.ffi.ffi_lowering(
            GroupedGemmCopySizesPrimitive.name,
            # Alias operand 0 (group_sizes) to output 0: the FFI call only
            # stages a D2H copy and returns the input buffer unchanged.
            operand_output_aliases={0: 0},
        )(
            ctx,
            group_sizes,
            num_gemms=num_gemms,
        )

    @staticmethod
    def impl(
        group_sizes,
        num_gemms,
    ):
        assert GroupedGemmCopySizesPrimitive.inner_primitive is not None
        out = GroupedGemmCopySizesPrimitive.inner_primitive.bind(
            group_sizes,
            num_gemms=num_gemms,
        )
        return out


register_primitive(GroupedGemmCopySizesPrimitive)
class GroupedGemmPrimitive(BasePrimitive): class GroupedGemmPrimitive(BasePrimitive):
""" """
Primitive for grouped GEMM Primitive for grouped GEMM
...@@ -1244,7 +1302,7 @@ class GroupedGemmPrimitive(BasePrimitive): ...@@ -1244,7 +1302,7 @@ class GroupedGemmPrimitive(BasePrimitive):
name = "te_grouped_gemm_ffi" name = "te_grouped_gemm_ffi"
multiple_results = True multiple_results = True
impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15) impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -1267,6 +1325,7 @@ class GroupedGemmPrimitive(BasePrimitive): ...@@ -1267,6 +1325,7 @@ class GroupedGemmPrimitive(BasePrimitive):
out_dtype, out_dtype,
has_bias, has_bias,
is_grouped_dense_wgrad, is_grouped_dense_wgrad,
use_async_d2h_group_sizes,
): ):
""" """
Grouped GEMM operation. Grouped GEMM operation.
...@@ -1294,7 +1353,7 @@ class GroupedGemmPrimitive(BasePrimitive): ...@@ -1294,7 +1353,7 @@ class GroupedGemmPrimitive(BasePrimitive):
A jnp.ndarray containing the result of the grouped GEMM operation A jnp.ndarray containing the result of the grouped GEMM operation
""" """
del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
del K, lhs_is_trans, rhs_is_trans, has_bias del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes
# TODO(Phuong): move some shape checks from Cpp to here # TODO(Phuong): move some shape checks from Cpp to here
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
workspace_alignment_padding = 256 workspace_alignment_padding = 256
...@@ -1341,6 +1400,7 @@ class GroupedGemmPrimitive(BasePrimitive): ...@@ -1341,6 +1400,7 @@ class GroupedGemmPrimitive(BasePrimitive):
out_dtype, out_dtype,
has_bias, has_bias,
is_grouped_dense_wgrad, is_grouped_dense_wgrad,
use_async_d2h_group_sizes,
): ):
del out_dtype del out_dtype
return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
...@@ -1354,6 +1414,7 @@ class GroupedGemmPrimitive(BasePrimitive): ...@@ -1354,6 +1414,7 @@ class GroupedGemmPrimitive(BasePrimitive):
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
has_bias=has_bias, has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad, is_grouped_dense_wgrad=is_grouped_dense_wgrad,
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
) )
@staticmethod @staticmethod
...@@ -1374,6 +1435,7 @@ class GroupedGemmPrimitive(BasePrimitive): ...@@ -1374,6 +1435,7 @@ class GroupedGemmPrimitive(BasePrimitive):
out_dtype, out_dtype,
has_bias, has_bias,
is_grouped_dense_wgrad, is_grouped_dense_wgrad,
use_async_d2h_group_sizes,
): ):
assert GroupedGemmPrimitive.inner_primitive is not None assert GroupedGemmPrimitive.inner_primitive is not None
(out, _) = GroupedGemmPrimitive.inner_primitive.bind( (out, _) = GroupedGemmPrimitive.inner_primitive.bind(
...@@ -1393,6 +1455,7 @@ class GroupedGemmPrimitive(BasePrimitive): ...@@ -1393,6 +1455,7 @@ class GroupedGemmPrimitive(BasePrimitive):
out_dtype=out_dtype, out_dtype=out_dtype,
has_bias=has_bias, has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad, is_grouped_dense_wgrad=is_grouped_dense_wgrad,
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
) )
return (out,) return (out,)
...@@ -1661,6 +1724,24 @@ def gemm( ...@@ -1661,6 +1724,24 @@ def gemm(
return clean_outputs return clean_outputs
def grouped_gemm_copy_group_sizes(
    group_sizes: jnp.ndarray,
    num_gemms: int,
) -> jnp.ndarray:
    """
    Issue an async device-to-host copy of the grouped-GEMM group sizes.

    Args:
        group_sizes: 1D array containing the sizes of each group
        num_gemms: number of grouped gemm calls to be made

    Returns:
        The group_sizes array (the primitive's output aliases its input);
        returning it keeps the copy live inside a jitted computation.
    """
    return GroupedGemmCopySizesPrimitive.outer_primitive.bind(
        group_sizes, num_gemms=num_gemms
    )
def grouped_gemm( def grouped_gemm(
lhs: Union[jnp.ndarray, GroupedScaledTensor1x], lhs: Union[jnp.ndarray, GroupedScaledTensor1x],
rhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x],
...@@ -1671,6 +1752,7 @@ def grouped_gemm( ...@@ -1671,6 +1752,7 @@ def grouped_gemm(
preferred_element_type: jnp.dtype = None, preferred_element_type: jnp.dtype = None,
group_offset: jnp.array = None, group_offset: jnp.array = None,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
use_async_d2h_group_sizes: bool = False,
) -> jnp.ndarray: ) -> jnp.ndarray:
""" """
Grouped GEMM operation. Grouped GEMM operation.
...@@ -1854,5 +1936,6 @@ def grouped_gemm( ...@@ -1854,5 +1936,6 @@ def grouped_gemm(
out_dtype=out_dtype, out_dtype=out_dtype,
has_bias=has_bias, has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad, is_grouped_dense_wgrad=is_grouped_dense_wgrad,
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
) )
return out return out
...@@ -135,6 +135,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); ...@@ -135,6 +135,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler);
// Grouped GEMM // Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
// Cudnn helpers // Cudnn helpers
......
...@@ -284,12 +284,71 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, ...@@ -284,12 +284,71 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
.Attr<JAXX_Collective_Op>("collective_op"), .Attr<JAXX_Collective_Op>("collective_op"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
// Stages an async D2H copy of the grouped-GEMM group_sizes array, or retrieves
// the previously staged sizes on the host.
//
// At most one of dev_group_sizes / host_group_sizes may be non-null:
//  - dev_group_sizes != nullptr: record an event on `stream`, make compute
//    stream 0 wait on it, then enqueue an async D2H memcpy of num_gemms int32
//    values into a pinned staging buffer. Returns num_gemms.
//  - host_group_sizes != nullptr: block until the staged copy completes, then
//    memcpy the sizes into the caller's buffer. Returns the staged num_gemms
//    (0 if nothing was staged).
//  - both null: no-op, returns 0.
//
// NOTE(review): state is kept in function-local statics, so this supports only
// one in-flight copy and is not safe for concurrent callers — confirm callers
// serialize access.
size_t GroupedGemmGetGroupSizes(cudaStream_t stream, size_t num_gemms, int32_t *dev_group_sizes,
                                int32_t *host_group_sizes) {
  static std::once_flag init_flag;
  static cudaEvent_t d2h_event;
  static size_t host_num_gemms;
  static const size_t max_num_gemms = 1024;
  static int32_t *host_group_sizes_internal = nullptr;
  auto init = [&]() {
    NVTE_CHECK_CUDA(cudaEventCreate(&d2h_event));
    // Pinned host memory so the cudaMemcpyAsync below is truly asynchronous.
    NVTE_CHECK_CUDA(cudaMallocHost(&host_group_sizes_internal, sizeof(int32_t) * max_num_gemms));
  };
  std::call_once(init_flag, init);

  NVTE_CHECK(dev_group_sizes == nullptr || host_group_sizes == nullptr,
             "Only one of dev_group_sizes and host_group_sizes can be non-nullptr.");

  if (dev_group_sizes != nullptr) {
    NVTE_CHECK(num_gemms <= max_num_gemms, "num_gemms ", num_gemms, " exceeds the maximum ",
               "supported number ", max_num_gemms, " to be downloaded in advance.");
    host_num_gemms = num_gemms;
    // Wait for current compute stream to finish
    cudaStream_t compute_stream_0 = nvte_get_compute_stream(0);
    NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, stream));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_stream_0, d2h_event));
    // Async copy group_sizes from device to host
    size_t copy_bytes = sizeof(int32_t) * num_gemms;
    NVTE_CHECK_CUDA(cudaMemcpyAsync(host_group_sizes_internal, dev_group_sizes, copy_bytes,
                                    cudaMemcpyDeviceToHost, compute_stream_0));
    NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, compute_stream_0));
    return num_gemms;
  }

  if (host_group_sizes != nullptr) {
    if (host_num_gemms == 0) return 0;
    NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms,
               " does not match the previous value ", host_num_gemms, ".");
    // Wait for the async copy to finish, then copy group_sizes to user buffer
    // Note: This may break cudaGraph.
    NVTE_CHECK_CUDA(cudaEventSynchronize(d2h_event));
    memcpy(host_group_sizes, host_group_sizes_internal, sizeof(int32_t) * host_num_gemms);
    return host_num_gemms;
  }

  // Both pointers null: explicit return avoids falling off the end of a
  // value-returning function, which is undefined behavior in C++.
  return 0;
}
// FFI entry point: stage an async D2H copy of `group_sizes` (num_gemms int32
// values). `dummy_output` aliases the input on the JAX side and is unused here.
Error_Type GroupedGemmD2HGroupSizesFFI(cudaStream_t stream, Buffer_Type group_sizes,
                                       Result_Type dummy_output, size_t num_gemms) {
  int32_t *dev_group_sizes = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
  GroupedGemmGetGroupSizes(stream, num_gemms, dev_group_sizes, nullptr);
  return ffi_with_cuda_error_check();
}
// Register the D2H group-sizes FFI handler symbol with XLA.
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGroupSizesFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // group_sizes
                                  .Ret<Buffer_Type>()      // dummy_output
                                  .Attr<int64_t>("num_gemms"));
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans,
bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias,
bool is_grouped_dense_wgrad) { bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) {
// Notes on matrix layouts and transpose: // Notes on matrix layouts and transpose:
// Jax uses row-major data_layout, on entering this function, each input matrix pair: // Jax uses row-major data_layout, on entering this function, each input matrix pair:
// A: row-major [m, k] for N - [k, m] for T // A: row-major [m, k] for N - [k, m] for T
...@@ -410,11 +469,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type ...@@ -410,11 +469,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
size_t dim_list_bytes = sizeof(int32_t) * num_gemms; size_t dim_list_bytes = sizeof(int32_t) * num_gemms;
std::vector<int32_t> dim_list_host(num_gemms); std::vector<int32_t> dim_list_host(num_gemms);
size_t host_num_gemms = 0;
if (use_async_d2h_group_sizes) {
host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data());
NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms,
" does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, ".");
} else {
auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data()); auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream); stream);
// Note: This may break cudaGraph. // Note: This may break cudaGraph.
cudaStreamSynchronize(stream); cudaStreamSynchronize(stream);
}
size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
if (!is_grouped_dense_wgrad) { if (!is_grouped_dense_wgrad) {
NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
...@@ -673,7 +739,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, ...@@ -673,7 +739,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
.Attr<bool>("rhs_is_trans") .Attr<bool>("rhs_is_trans")
.Attr<JAXX_Scaling_Mode>("scaling_mode") .Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("has_bias") .Attr<bool>("has_bias")
.Attr<bool>("is_grouped_dense_wgrad")); .Attr<bool>("is_grouped_dense_wgrad")
.Attr<bool>("use_async_d2h_group_sizes"));
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -69,6 +69,9 @@ pybind11::dict Registrations() { ...@@ -69,6 +69,9 @@ pybind11::dict Registrations() {
pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); pybind11::arg("execute") = EncapsulateFFI(GemmHandler));
// Grouped GEMM // Grouped GEMM
dict["te_grouped_gemm_d2h_group_sizes_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler));
dict["te_grouped_gemm_ffi"] = dict["te_grouped_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment