Unverified commit af2a0c16, authored by Hua Huang, committed by GitHub
Browse files

[JAX] Async issuing D2H memcpy for grouped_gemm group_sizes array (#2213)



* Try async copy of grouped GEMM group_sizes data
Signed-off-by: Hua Huang <huah@nvidia.com>

---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 66f9b3cb
......@@ -1366,14 +1366,22 @@ class TestGroupedDense:
lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input(
dtype, input_shape, layout
)
num_gemms = input_shape[0]
_ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))(
group_sizes,
num_gemms=num_gemms,
)
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)
# jitting grouped_gemm
prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
prim_out = jax.jit(
tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes")
)(
lhs,
rhs,
group_sizes,
contracting_dims,
use_async_d2h_group_sizes=True,
)
self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
......
......@@ -58,6 +58,7 @@ __all__ = [
"collective_gemm_bootstrap",
"noop_collective_op_set",
"gemm",
"grouped_gemm_copy_group_sizes",
"grouped_gemm",
"gemm_uses_jax_dot",
"sanitize_dims",
......@@ -1237,6 +1238,63 @@ def _te_gemm(
)
class GroupedGemmCopySizesPrimitive(BasePrimitive):
    """
    Primitive that issues the asynchronous device-to-host copy of ``group_sizes``.

    The primitive lowers to the ``te_grouped_gemm_d2h_group_sizes_ffi`` custom call.
    Its single result has the same abstract value as the ``group_sizes`` operand and
    is aliased to that operand in the lowering.
    """

    name = "te_grouped_gemm_d2h_group_sizes_ffi"
    multiple_results = False
    # Positional index of ``num_gemms`` in ``impl`` (presumably consumed by
    # ``register_primitive`` to mark it as a static argument -- confirm there).
    impl_static_args = (1,)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(group_sizes_aval, *, num_gemms):
        """Return the output abstract value, which is exactly ``group_sizes_aval``."""
        del num_gemms  # Does not influence the output shape or dtype.
        return group_sizes_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """Forward to :meth:`abstract`; the outer primitive has the same output."""
        return GroupedGemmCopySizesPrimitive.abstract(*args, **kwargs)

    @staticmethod
    def lowering(ctx, group_sizes, num_gemms):
        """Lower to the FFI custom call, passing ``num_gemms`` as an attribute."""
        lower_to_ffi = jax.ffi.ffi_lowering(
            GroupedGemmCopySizesPrimitive.name,
            # Alias operand 0 (group_sizes) to result 0.
            operand_output_aliases={0: 0},
        )
        return lower_to_ffi(ctx, group_sizes, num_gemms=num_gemms)

    @staticmethod
    def impl(group_sizes, num_gemms):
        """Bind the inner primitive; it must have been registered beforehand."""
        assert GroupedGemmCopySizesPrimitive.inner_primitive is not None
        return GroupedGemmCopySizesPrimitive.inner_primitive.bind(
            group_sizes,
            num_gemms=num_gemms,
        )
register_primitive(GroupedGemmCopySizesPrimitive)
class GroupedGemmPrimitive(BasePrimitive):
"""
Primitive for grouped GEMM
......@@ -1244,7 +1302,7 @@ class GroupedGemmPrimitive(BasePrimitive):
name = "te_grouped_gemm_ffi"
multiple_results = True
impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15)
impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
inner_primitive = None
outer_primitive = None
......@@ -1267,6 +1325,7 @@ class GroupedGemmPrimitive(BasePrimitive):
out_dtype,
has_bias,
is_grouped_dense_wgrad,
use_async_d2h_group_sizes,
):
"""
Grouped GEMM operation.
......@@ -1294,7 +1353,7 @@ class GroupedGemmPrimitive(BasePrimitive):
A jnp.ndarray containing the result of the grouped GEMM operation
"""
del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
del K, lhs_is_trans, rhs_is_trans, has_bias
del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes
# TODO(Phuong): move some shape checks from Cpp to here
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
workspace_alignment_padding = 256
......@@ -1341,6 +1400,7 @@ class GroupedGemmPrimitive(BasePrimitive):
out_dtype,
has_bias,
is_grouped_dense_wgrad,
use_async_d2h_group_sizes,
):
del out_dtype
return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
......@@ -1354,6 +1414,7 @@ class GroupedGemmPrimitive(BasePrimitive):
scaling_mode=scaling_mode.value,
has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
)
@staticmethod
......@@ -1374,6 +1435,7 @@ class GroupedGemmPrimitive(BasePrimitive):
out_dtype,
has_bias,
is_grouped_dense_wgrad,
use_async_d2h_group_sizes,
):
assert GroupedGemmPrimitive.inner_primitive is not None
(out, _) = GroupedGemmPrimitive.inner_primitive.bind(
......@@ -1393,6 +1455,7 @@ class GroupedGemmPrimitive(BasePrimitive):
out_dtype=out_dtype,
has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
)
return (out,)
......@@ -1661,6 +1724,24 @@ def gemm(
return clean_outputs
def grouped_gemm_copy_group_sizes(
    group_sizes: jnp.ndarray,
    num_gemms: int,
) -> jnp.ndarray:
    """
    Issue the asynchronous device-to-host copy of the grouped GEMM group sizes.

    Args:
        group_sizes: 1D array containing the sizes of each group
        num_gemms: number of grouped gemm calls to be made; forwarded to the
            primitive as a keyword parameter

    Returns:
        The output of ``GroupedGemmCopySizesPrimitive``, whose abstract value
        matches ``group_sizes``.
    """
    copy_primitive = GroupedGemmCopySizesPrimitive.outer_primitive
    return copy_primitive.bind(group_sizes, num_gemms=num_gemms)
def grouped_gemm(
lhs: Union[jnp.ndarray, GroupedScaledTensor1x],
rhs: Union[jnp.ndarray, GroupedScaledTensor1x],
......@@ -1671,6 +1752,7 @@ def grouped_gemm(
preferred_element_type: jnp.dtype = None,
group_offset: jnp.array = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
use_async_d2h_group_sizes: bool = False,
) -> jnp.ndarray:
"""
Grouped GEMM operation.
......@@ -1854,5 +1936,6 @@ def grouped_gemm(
out_dtype=out_dtype,
has_bias=has_bias,
is_grouped_dense_wgrad=is_grouped_dense_wgrad,
use_async_d2h_group_sizes=use_async_d2h_group_sizes,
)
return out
......@@ -135,6 +135,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler);
// Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
// Cudnn helpers
......
......@@ -284,12 +284,71 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
.Attr<JAXX_Collective_Op>("collective_op"),
FFI_CudaGraph_Traits);
// Stages grouped-GEMM group sizes on the host through a process-wide staging buffer.
//
// The function has two modes, selected by which pointer is non-null:
//   * dev_group_sizes != nullptr ("download"): enqueue an async device-to-host copy of
//     num_gemms int32 entries into the internal staging buffer and return num_gemms.
//   * host_group_sizes != nullptr ("fetch"): wait for the most recent download to finish,
//     copy the staged entries into host_group_sizes and return how many were copied.
//     Returns 0 without touching host_group_sizes if nothing has been staged.
//
// Exactly one of the two pointers must be non-null; any other combination raises through
// NVTE_CHECK. num_gemms must not exceed the staging capacity (1024) in download mode and
// must equal the staged count in fetch mode.
//
// NOTE(review): all state lives in function-local statics (one event, one buffer, one
// count), so every download overwrites the previous one and concurrent callers share the
// same state -- confirm that callers serialize download/fetch pairs.
size_t GroupedGemmGetGroupSizes(cudaStream_t stream, size_t num_gemms, int32_t *dev_group_sizes,
                                int32_t *host_group_sizes) {
  static std::once_flag init_flag;
  static cudaEvent_t d2h_event;
  // Number of entries staged by the most recent download; stays 0 until the first download.
  static size_t host_num_gemms = 0;
  static const size_t max_num_gemms = 1024;
  // Staging buffer allocated once with cudaMallocHost and never freed.
  static int32_t *host_group_sizes_internal = nullptr;
  auto init = [&]() {
    NVTE_CHECK_CUDA(cudaEventCreate(&d2h_event));
    NVTE_CHECK_CUDA(cudaMallocHost(&host_group_sizes_internal, sizeof(int32_t) * max_num_gemms));
  };
  std::call_once(init_flag, init);

  NVTE_CHECK(dev_group_sizes == nullptr || host_group_sizes == nullptr,
             "Only one of dev_group_sizes and host_group_sizes can be non-nullptr.");
  // Reject the "both null" case explicitly: previously it skipped both branches and
  // flowed off the end of this non-void function, which is undefined behaviour.
  NVTE_CHECK(dev_group_sizes != nullptr || host_group_sizes != nullptr,
             "One of dev_group_sizes and host_group_sizes must be non-nullptr.");

  if (dev_group_sizes != nullptr) {
    NVTE_CHECK(num_gemms <= max_num_gemms, "num_gemms ", num_gemms, " exceeds the maximum ",
               "supported number ", max_num_gemms, " to be downloaded in advance.");
    host_num_gemms = num_gemms;

    // Make compute stream 0 wait for the work already enqueued on the caller's stream,
    // so the copy below observes the final contents of dev_group_sizes.
    cudaStream_t compute_stream_0 = nvte_get_compute_stream(0);
    NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, stream));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_stream_0, d2h_event));

    // Async copy group_sizes from device to host
    size_t copy_bytes = sizeof(int32_t) * num_gemms;
    NVTE_CHECK_CUDA(cudaMemcpyAsync(host_group_sizes_internal, dev_group_sizes, copy_bytes,
                                    cudaMemcpyDeviceToHost, compute_stream_0));
    // Re-record the event after the copy so fetch mode can wait for its completion.
    NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, compute_stream_0));
    return num_gemms;
  }

  // Fetch mode: host_group_sizes is guaranteed non-null by the checks above.
  if (host_num_gemms == 0) return 0;
  NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms,
             " does not match the previous value ", host_num_gemms, ".");
  // Wait for the async copy to finish, then copy group_sizes to user buffer
  // Note: This may break cudaGraph.
  NVTE_CHECK_CUDA(cudaEventSynchronize(d2h_event));
  memcpy(host_group_sizes, host_group_sizes_internal, sizeof(int32_t) * host_num_gemms);
  return host_num_gemms;
}
// FFI entry point that schedules the asynchronous device-to-host download of group_sizes.
// dummy_output is not read or written here.
Error_Type GroupedGemmD2HGroupSizesFFI(cudaStream_t stream, Buffer_Type group_sizes,
                                       Result_Type dummy_output, size_t num_gemms) {
  auto *group_sizes_on_device = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
  // A null host pointer selects the download path; the returned count is not needed here.
  GroupedGemmGetGroupSizes(stream, num_gemms, group_sizes_on_device,
                           /*host_group_sizes=*/nullptr);
  return ffi_with_cuda_error_check();
}
// Defines the GroupedGemmD2HGroupSizesHandler symbol for GroupedGemmD2HGroupSizesFFI.
// The binding supplies, in order: the stream context, one buffer argument, one buffer
// result and the "num_gemms" attribute.
// NOTE(review): the attribute is declared as int64_t while the handler function takes
// size_t num_gemms -- confirm that negative values cannot reach this call.
XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGroupSizesFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // group_sizes
.Ret<Buffer_Type>() // dummy_output
.Attr<int64_t>("num_gemms"));
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans,
bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias,
bool is_grouped_dense_wgrad) {
bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) {
// Notes on matrix layouts and transpose:
// Jax uses row-major data_layout, on entering this function, each input matrix pair:
// A: row-major [m, k] for N - [k, m] for T
......@@ -410,11 +469,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
size_t dim_list_bytes = sizeof(int32_t) * num_gemms;
std::vector<int32_t> dim_list_host(num_gemms);
size_t host_num_gemms = 0;
if (use_async_d2h_group_sizes) {
host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data());
NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms,
" does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, ".");
} else {
auto dim_list_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream);
// Note: This may break cudaGraph.
cudaStreamSynchronize(stream);
}
size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
if (!is_grouped_dense_wgrad) {
NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
......@@ -673,7 +739,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
.Attr<bool>("rhs_is_trans")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("has_bias")
.Attr<bool>("is_grouped_dense_wgrad"));
.Attr<bool>("is_grouped_dense_wgrad")
.Attr<bool>("use_async_d2h_group_sizes"));
} // namespace jax
} // namespace transformer_engine
......@@ -69,6 +69,9 @@ pybind11::dict Registrations() {
pybind11::arg("execute") = EncapsulateFFI(GemmHandler));
// Grouped GEMM
dict["te_grouped_gemm_d2h_group_sizes_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler));
dict["te_grouped_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment