    Flax safety checker (#825) · 78db11db
    Pedro Cuenca authored
    
    
    * Remove set_format in Flax pipeline.
    
    * Remove DummyChecker.
    
    * Run safety_checker in pipeline.
    
    * Don't pmap on every call.
    
    We could have decorated `generate` with `pmap`, but I wanted to keep it
    in case someone wants to invoke it in non-parallel mode.
    
    * Remove commented line
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
    
    * Replicate outside __call__, prepare for optional jitting.
    
    * Remove unnecessary clipping.
    
    As suggested by @kashif.
    
    * Do not jit unless requested.
    
    * Send all args to generate.
    
    * make style
    
    * Remove unused imports.
    
    * Fix docstring.
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
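The bullets about not calling pmap on every call, replicating outside `__call__`, and not jitting unless requested describe a common JAX pattern. What follows is a minimal sketch of it, not the pipeline's actual code. The toy `_generate` function, the `call_pipeline` wrapper, and the identity weight matrix are all hypothetical. The sketch builds the parallel version once with `jax.pmap` and keeps the plain function callable in non-parallel mode. It also has the caller replicate parameters before the call rather than inside it.

```python
import jax
import jax.numpy as jnp


def _generate(params, x):
    # Hypothetical stand-in for the pipeline's generation loop.
    return jnp.tanh(x @ params)


# Wrap once at module level instead of calling jax.pmap inside every call,
# which would build a new wrapped function each time.
_p_generate = jax.pmap(_generate)


def call_pipeline(params, x, jit=False):
    """Hypothetical __call__ that only takes the parallel path when asked.

    With jit=True, params must already carry a leading device axis
    (replicated by the caller) and x must be sharded the same way.
    """
    if jit:
        return _p_generate(params, x)
    # Non-parallel mode: call the plain function directly.
    return _generate(params, x)


# Usage: replicate once, outside the call.
n_devices = jax.local_device_count()
params = jnp.eye(3)
x = jnp.ones((n_devices, 2, 3))
# Simple stand-in for flax.jax_utils.replicate: add a leading device axis.
replicated = jnp.broadcast_to(params, (n_devices,) + params.shape)

parallel_out = call_pipeline(replicated, x, jit=True)  # shape (n_devices, 2, 3)
single_out = call_pipeline(params, x[0])                # shape (2, 3)
```

Keeping `_generate` unwrapped is the design choice the commit describes. Decorating it with `@jax.pmap` would make the non-parallel path unavailable.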