Unverified Commit 5afbb0e1 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Make softmax_type in FFI optional (#2491)



* Make softmax_type in FFI optional
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add warn message
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 46c6ef31
......@@ -333,34 +333,35 @@ static void FusedAttnForwardImpl(
nvte_tensor_pack_destroy(&aux_output_tensors);
}
#define FUSED_ATTN_FFI_GET_ATTRS \
size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch"); \
size_t bias_batch = get_attr_value<int64_t>(attrs, "bias_batch"); \
size_t q_max_seqlen = get_attr_value<int64_t>(attrs, "q_max_seqlen"); \
size_t kv_max_seqlen = get_attr_value<int64_t>(attrs, "kv_max_seqlen"); \
size_t attn_heads = get_attr_value<int64_t>(attrs, "attn_heads"); \
size_t num_gqa_groups = get_attr_value<int64_t>(attrs, "num_gqa_groups"); \
size_t bias_heads = get_attr_value<int64_t>(attrs, "bias_heads"); \
size_t qk_head_dim = get_attr_value<int64_t>(attrs, "qk_head_dim"); \
size_t v_head_dim = get_attr_value<int64_t>(attrs, "v_head_dim"); \
size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq"); \
auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left"); \
auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right"); \
float scaling_factor = get_attr_value<double>(attrs, "scaling_factor"); \
float dropout_probability = get_attr_value<double>(attrs, "dropout_probability"); \
NVTE_Bias_Type bias_type = \
static_cast<NVTE_Bias_Type>(get_attr_value<int64_t>(attrs, "bias_type")); \
NVTE_Mask_Type mask_type = \
static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type")); \
NVTE_Softmax_Type softmax_type = \
static_cast<NVTE_Softmax_Type>(get_attr_value<int64_t>(attrs, "softmax_type")); \
NVTE_QKV_Layout qkv_layout = \
static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout")); \
bool is_training = get_attr_value<bool>(attrs, "is_training"); \
bool deterministic = get_attr_value<bool>(attrs, "deterministic"); \
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; \
size_t wkspace_size = product(workspace_buf->dimensions()); \
DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type()); \
// Expands to a sequence of local-variable declarations that pull every
// fused-attention call attribute out of the XLA FFI dictionary `attrs`
// (sizes, enum-valued int64 attrs cast to their NVTE enum types, flags, and
// dtypes derived from the buffers). The expansion site must have `attrs`,
// `q_buf`, `workspace_buf`, `get_attr_value`, and `get_attr_value_or_default`
// in scope.
//
// "softmax_type" is the one OPTIONAL attribute: when it is absent from the
// dictionary, get_attr_value_or_default logs a warning and falls back to
// NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, so frontends that predate the
// softmax_type attribute keep working. All other attributes are mandatory.
#define FUSED_ATTN_FFI_GET_ATTRS \
size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch"); \
size_t bias_batch = get_attr_value<int64_t>(attrs, "bias_batch"); \
size_t q_max_seqlen = get_attr_value<int64_t>(attrs, "q_max_seqlen"); \
size_t kv_max_seqlen = get_attr_value<int64_t>(attrs, "kv_max_seqlen"); \
size_t attn_heads = get_attr_value<int64_t>(attrs, "attn_heads"); \
size_t num_gqa_groups = get_attr_value<int64_t>(attrs, "num_gqa_groups"); \
size_t bias_heads = get_attr_value<int64_t>(attrs, "bias_heads"); \
size_t qk_head_dim = get_attr_value<int64_t>(attrs, "qk_head_dim"); \
size_t v_head_dim = get_attr_value<int64_t>(attrs, "v_head_dim"); \
size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq"); \
auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left"); \
auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right"); \
float scaling_factor = get_attr_value<double>(attrs, "scaling_factor"); \
float dropout_probability = get_attr_value<double>(attrs, "dropout_probability"); \
NVTE_Bias_Type bias_type = \
static_cast<NVTE_Bias_Type>(get_attr_value<int64_t>(attrs, "bias_type")); \
NVTE_Mask_Type mask_type = \
static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type")); \
NVTE_Softmax_Type softmax_type = \
static_cast<NVTE_Softmax_Type>(get_attr_value_or_default<int64_t>( \
attrs, "softmax_type", static_cast<int64_t>(NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))); \
NVTE_QKV_Layout qkv_layout = \
static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout")); \
bool is_training = get_attr_value<bool>(attrs, "is_training"); \
bool deterministic = get_attr_value<bool>(attrs, "deterministic"); \
auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; \
size_t wkspace_size = product(workspace_buf->dimensions()); \
DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type()); \
DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
......
......@@ -75,6 +75,21 @@ T get_attr_value(Dictionary& attrs, std::string attr_name,
return attr.value();
}
/// Fetches attribute `attr_name` of type `T` from the FFI dictionary,
/// returning `default_value` instead of failing when the lookup errors out
/// (attribute missing or datatype mismatch). A warning that includes the
/// caller's source location is emitted on the fallback path so silently
/// divergent C++/Python attribute names are still visible in the logs.
template <typename T>
T get_attr_value_or_default(Dictionary& attrs, std::string attr_name, T default_value,
                            const source_location& loc = source_location::current()) {
  const auto maybe_attr = attrs.get<T>(attr_name);
  if (!maybe_attr.has_error()) {
    return maybe_attr.value();
  }
  // Unlike get_attr_value, a failed lookup is non-fatal here: warn and fall
  // back to the caller-supplied default.
  NVTE_WARN("Failure in getting attribute value of '", attr_name, "'\n",
            "Called from: ", loc.file_name(), ":", loc.line(), "\n",
            "In function: ", loc.function_name(), "\n",
            "Please ensure the attribute name and datatype match between C++ and Python APIs. "
            "Currently falling back to a default value.");
  return default_value;
}
inline size_t product(const xla::ffi::Span<const int64_t>& data, size_t start_idx = 0,
size_t end_idx = 0) {
end_idx = (end_idx == 0) ? data.size() : end_idx;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment