Commit a207db1d authored by yuguo's avatar yuguo
Browse files
parents fbee8990 69365f88
This diff is collapsed.
......@@ -34,5 +34,11 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret;
}
enum class QuantizeAxis {
ROWWISE,
COLWISE,
ROWWISE_COLWISE,
};
} // namespace jax
} // namespace transformer_engine
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
from .module import DenseGeneral, LayerNorm
from .module import LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase
from .module import LayerNormDenseGeneral, LayerNormMLP
from .transformer import extend_logical_axis_rules
from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases
from .transformer import TransformerLayer, TransformerLayerType
......@@ -13,7 +13,6 @@ __all__ = [
"LayerNorm",
"LayerNormDenseGeneral",
"LayerNormMLP",
"TransformerEngineBase",
"extend_logical_axis_rules",
"DotProductAttention",
"MultiHeadAttention",
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment