Unverified Commit 5afbb0e1 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Make softmax_type in FFI optional (#2491)



* Make softmax_type in FFI optional
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add warn message
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 46c6ef31
...@@ -353,7 +353,8 @@ static void FusedAttnForwardImpl( ...@@ -353,7 +353,8 @@ static void FusedAttnForwardImpl(
NVTE_Mask_Type mask_type = \ NVTE_Mask_Type mask_type = \
static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type")); \ static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type")); \
NVTE_Softmax_Type softmax_type = \ NVTE_Softmax_Type softmax_type = \
static_cast<NVTE_Softmax_Type>(get_attr_value<int64_t>(attrs, "softmax_type")); \ static_cast<NVTE_Softmax_Type>(get_attr_value_or_default<int64_t>( \
attrs, "softmax_type", static_cast<int64_t>(NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))); \
NVTE_QKV_Layout qkv_layout = \ NVTE_QKV_Layout qkv_layout = \
static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout")); \ static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout")); \
bool is_training = get_attr_value<bool>(attrs, "is_training"); \ bool is_training = get_attr_value<bool>(attrs, "is_training"); \
......
...@@ -75,6 +75,21 @@ T get_attr_value(Dictionary& attrs, std::string attr_name, ...@@ -75,6 +75,21 @@ T get_attr_value(Dictionary& attrs, std::string attr_name,
return attr.value(); return attr.value();
} }
template <typename T>
T get_attr_value_or_default(Dictionary& attrs, std::string attr_name, T default_value,
const source_location& loc = source_location::current()) {
auto attr = attrs.get<T>(attr_name);
if (attr.has_error()) {
NVTE_WARN("Failure in getting attribute value of '", attr_name, "'\n",
"Called from: ", loc.file_name(), ":", loc.line(), "\n",
"In function: ", loc.function_name(), "\n",
"Please ensure the attribute name and datatype match between C++ and Python APIs. "
"Currently falling back to a default value.");
return default_value;
}
return attr.value();
}
inline size_t product(const xla::ffi::Span<const int64_t>& data, size_t start_idx = 0, inline size_t product(const xla::ffi::Span<const int64_t>& data, size_t start_idx = 0,
size_t end_idx = 0) { size_t end_idx = 0) {
end_idx = (end_idx == 0) ? data.size() : end_idx; end_idx = (end_idx == 0) ? data.size() : end_idx;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment